def add_special_input_features(self, input_example: InputExample,
                               input_features: InputFeatures) -> None:
    mask_start = input_features.input_ids.index(self.wrapper.tokenizer.mask_token_id)

    choices = input_example.meta['candidates']
    question_idx = input_example.meta['question_idx']

    input_features.meta['candidate_token_ids'] = []
    input_features.meta['candidate_labels'] = []
    input_features.meta['question_idx'] = question_idx

    self.original_choices[question_idx] = []

    for idx, choice_text in enumerate(choices):
        choice_token_ids = get_verbalization_ids(choice_text, self.wrapper.tokenizer,
                                                 force_single_token=False)
        choice_label = 1 if choice_text in input_example.meta['answers'] else 0

        mask_end = mask_start + len(choice_token_ids)
        candidate_token_ids = [-100] * len(input_features.input_ids)
        candidate_token_ids[mask_start:mask_end] = choice_token_ids

        input_features.meta['candidate_token_ids'].append(candidate_token_ids)
        input_features.meta['candidate_labels'].append(choice_label)
        self.original_choices[question_idx].append(choice_text)
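# Illustrative toy example (values assumed, not taken from the original data): with a
# mask span starting at position 3 and a candidate that tokenizes to two ids, the
# candidate_token_ids list built above keeps the candidate's ids aligned with the
# <mask> positions and marks every other position with -100 so that downstream
# loss / scoring code can ignore them.
example_input_ids = [0, 11, 22, 50264, 50264, 33, 2]            # 50264 = <mask> (RoBERTa)
example_candidate_token_ids = [-100, -100, -100, 4874, 354, -100, -100]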
def add_special_input_features(self, input_example: InputExample,
                               input_features: InputFeatures) -> None:
    if self.wrapper.config.wrapper_type == 'sequence_classifier':
        return

    mask_start = input_features.input_ids.index(self.wrapper.tokenizer.mask_token_id)

    for choice in ['choice1', 'choice2']:
        choice_text = input_example.meta[choice]
        choice_token_ids = get_verbalization_ids(choice_text, self.wrapper.tokenizer,
                                                 force_single_token=False)
        mask_end = mask_start + len(choice_token_ids)
        input_features.meta[f'{choice}_token_ids'] = [-100] * len(input_features.input_ids)
        input_features.meta[f'{choice}_token_ids'][mask_start:mask_end] = choice_token_ids
def add_special_input_features(self, input_example: InputExample,
                               input_features: InputFeatures) -> None:
    if self.wrapper.config.wrapper_type == 'sequence_classifier':
        return

    mask_start = input_features.input_ids.index(self.wrapper.tokenizer.mask_token_id)

    if 'choices' in input_example.meta:
        choices = [choice for choice in input_example.meta['choices']]
    else:
        label_list = self.wrapper.config.label_list
        choices = [self.wrapper.preprocessor.pvp.verbalize(label)[0] for label in label_list]

    input_features.meta['choice_token_ids'] = []

    for idx, choice_text in enumerate(choices):
        choice_token_ids = get_verbalization_ids(choice_text, self.wrapper.tokenizer,
                                                 force_single_token=False)
        mask_end = mask_start + len(choice_token_ids)
        candidate_token_ids = [-100] * len(input_features.input_ids)
        candidate_token_ids[mask_start:mask_end] = choice_token_ids
        input_features.meta['choice_token_ids'].append(candidate_token_ids)
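# A minimal scoring sketch (an assumption, not part of this repo) showing one way the
# -100-padded choice_token_ids built above could be consumed: sum the masked-LM
# log-probabilities of each choice's tokens at its mask positions and pick the
# highest-scoring choice. `model` is assumed to be any Hugging Face masked-LM model
# whose forward pass returns `.logits`; all tensors are 1-D of length seq_len.
import torch

def score_choices(model, input_ids, attention_mask, choice_token_ids_list):
    """Return one log-probability score per choice for a single example."""
    with torch.no_grad():
        logits = model(input_ids=input_ids.unsqueeze(0),
                       attention_mask=attention_mask.unsqueeze(0)).logits[0]
    log_probs = torch.log_softmax(logits, dim=-1)                # [seq_len, vocab]
    scores = []
    for choice_token_ids in choice_token_ids_list:               # one [seq_len] tensor per choice
        mask = choice_token_ids != -100                          # positions belonging to this choice
        token_ids = choice_token_ids.clone()
        token_ids[~mask] = 0                                     # dummy index so gather stays in range
        token_log_probs = log_probs.gather(-1, token_ids.unsqueeze(-1)).squeeze(-1)
        scores.append((token_log_probs * mask).sum().item())
    return scores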
def add_special_input_features(self, input_example: InputExample,
                               input_features: InputFeatures) -> None:
    if self.wrapper.config.wrapper_type == 'sequence_classifier':
        return

    mask_start = input_features.input_ids.index(self.wrapper.tokenizer.mask_token_id)
    num_masks = input_features.input_ids.count(self.wrapper.tokenizer.mask_token_id)
    mask_end = mask_start + num_masks

    target = input_example.meta['span1_text']
    input_features.meta['target'] = target
    target_token_ids = get_verbalization_ids(target, self.wrapper.tokenizer,
                                             force_single_token=False)
    input_features.meta['target_token_ids'] = [-100] * len(input_features.input_ids)

    # we also predict <pad> tokens at the missing positions
    target_token_ids += [self.wrapper.tokenizer.pad_token_id] * (num_masks - len(target_token_ids))
    input_features.meta['target_token_ids'][mask_start:mask_end] = target_token_ids
def get_input_features(self, example: InputExample, labelled: bool,
                       priming: bool = False, **kwargs) -> InputFeatures:
    # Obtain the PVP encoding (pattern/template sentence + label mapping)
    input_ids, token_type_ids, block_flag = self.pvp.encode(example)
    attention_mask = [1] * len(input_ids)

    padding_length = self.wrapper.config.max_seq_length - len(input_ids)
    if padding_length < 0:
        raise ValueError(
            f"Maximum sequence length is too small, got {len(input_ids)} input ids")

    input_ids = input_ids + ([self.wrapper.tokenizer.pad_token_id] * padding_length)  # word-id sequence + padding
    attention_mask = attention_mask + ([0] * padding_length)
    token_type_ids = token_type_ids + ([0] * padding_length)
    block_flag = block_flag + ([0] * padding_length)

    assert len(input_ids) == self.wrapper.config.max_seq_length
    assert len(attention_mask) == self.wrapper.config.max_seq_length
    assert len(token_type_ids) == self.wrapper.config.max_seq_length
    assert len(block_flag) == self.wrapper.config.max_seq_length

    example_label = example.label
    example_task = example.task  # added by wjn: the task this example belongs to

    # added by wjn: the label map may use numeric labels (0, 1) while the gold label is a
    # string ('0', '1'), or vice versa, so convert between the two representations if the
    # lookup would fail
    if example_label not in self.label_map.keys():
        if type(example_label) == int:
            example_label = str(example_label)
        elif type(example_label) == str:
            example_label = int(example_label)

    label = self.label_map[example_label] if example.label is not None else -100
    task = task_to_id[example_task]  # added by wjn: index of the task within its group
    logits = example.logits if example.logits else [-1]

    if labelled:
        mlm_labels = self.pvp.get_mask_positions(input_ids)
    else:
        mlm_labels = [-1] * self.wrapper.config.max_seq_length

    return InputFeatures(
        guid=int(example.guid.split('-')[1]),
        input_ids=input_ids,
        attention_mask=attention_mask,
        token_type_ids=token_type_ids,
        task=task,
        label=label,  # e.g. 0, 1, ...
        mlm_labels=mlm_labels,  # e.g. [-1, -1, ..., -1, 1, -1, -1, ...]
        logits=logits,
        idx=example.idx,
        block_flag=block_flag)
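# A minimal batching sketch (an assumption, not taken from this repo) of how a list of
# the InputFeatures returned above could be packed into a TensorDataset; the field
# names follow the constructor call above.
import torch
from torch.utils.data import TensorDataset

def features_to_dataset(features):
    return TensorDataset(
        torch.tensor([f.input_ids for f in features], dtype=torch.long),
        torch.tensor([f.attention_mask for f in features], dtype=torch.long),
        torch.tensor([f.token_type_ids for f in features], dtype=torch.long),
        torch.tensor([f.label for f in features], dtype=torch.long),
        torch.tensor([f.mlm_labels for f in features], dtype=torch.long),
        torch.tensor([f.block_flag for f in features], dtype=torch.long),
    )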
def get_input_features(self, example: InputExample, **kwargs) -> InputFeatures:
    inputs = (self.wrapper.task_helper.get_sequence_classifier_inputs(example)
              if self.wrapper.task_helper else None)
    if inputs is None:
        inputs = self.wrapper.tokenizer.encode_plus(
            example.text_a if example.text_a else None,
            example.text_b if example.text_b else None,
            add_special_tokens=True,
            max_length=self.wrapper.config.max_seq_length,
            truncation=True,
        )
    input_ids, token_type_ids = inputs["input_ids"], inputs.get("token_type_ids")

    attention_mask = [1] * len(input_ids)
    padding_length = self.wrapper.config.max_seq_length - len(input_ids)

    input_ids = input_ids + ([self.wrapper.tokenizer.pad_token_id] * padding_length)
    attention_mask = attention_mask + ([0] * padding_length)
    if not token_type_ids:
        token_type_ids = [0] * self.wrapper.config.max_seq_length
    else:
        token_type_ids = token_type_ids + ([0] * padding_length)
    mlm_labels = [-1] * len(input_ids)

    assert len(input_ids) == self.wrapper.config.max_seq_length
    assert len(attention_mask) == self.wrapper.config.max_seq_length
    assert len(token_type_ids) == self.wrapper.config.max_seq_length

    label = self.label_map[example.label] if example.label is not None else -100
    logits = example.logits if example.logits else [-1]

    return InputFeatures(
        input_ids=input_ids,
        attention_mask=attention_mask,
        token_type_ids=token_type_ids,
        label=label,
        mlm_labels=mlm_labels,
        logits=logits,
        idx=example.idx,
    )
def get_input_features(self, example: InputExample, labelled: bool,
                       priming: bool = False, **kwargs) -> InputFeatures:
    input_ids, token_type_ids, block_flag = self.pvp.encode(example)
    attention_mask = [1] * len(input_ids)

    padding_length = self.wrapper.config.max_seq_length - len(input_ids)
    if padding_length < 0:
        raise ValueError(
            f"Maximum sequence length is too small, got {len(input_ids)} input ids")

    input_ids = input_ids + ([self.wrapper.tokenizer.pad_token_id] * padding_length)
    attention_mask = attention_mask + ([0] * padding_length)
    token_type_ids = token_type_ids + ([0] * padding_length)
    block_flag = block_flag + ([0] * padding_length)

    assert len(input_ids) == self.wrapper.config.max_seq_length
    assert len(attention_mask) == self.wrapper.config.max_seq_length
    assert len(token_type_ids) == self.wrapper.config.max_seq_length
    assert len(block_flag) == self.wrapper.config.max_seq_length

    label = self.label_map[example.label] if example.label is not None else -100
    logits = example.logits if example.logits else [-1]

    if labelled:
        mlm_labels = self.pvp.get_mask_positions(input_ids)
    else:
        mlm_labels = [-1] * self.wrapper.config.max_seq_length

    return InputFeatures(input_ids=input_ids,
                         attention_mask=attention_mask,
                         token_type_ids=token_type_ids,
                         label=label,
                         mlm_labels=mlm_labels,
                         logits=logits,
                         idx=example.idx,
                         block_flag=block_flag)
def add_special_input_features(self, input_example: InputExample,
                               input_features: InputFeatures) -> None:
    input_features.meta['question_idx'] = input_example.meta['question_idx']
def get_input_features(self, example: InputExample, labelled: bool,
                       priming: bool = False, **kwargs) -> InputFeatures:
    if priming:
        input_ids, token_type_ids = self.pvp.encode(example, priming=True)
        priming_data = example.meta['priming_data']  # type: List[InputExample]

        priming_input_ids = []
        for priming_example in priming_data:
            pe_input_ids, _ = self.pvp.encode(priming_example, priming=True, labeled=True)
            priming_input_ids += pe_input_ids

        input_ids = priming_input_ids + input_ids
        token_type_ids = self.wrapper.tokenizer.create_token_type_ids_from_sequences(input_ids)
        input_ids = self.wrapper.tokenizer.build_inputs_with_special_tokens(input_ids)
    else:
        input_ids, token_type_ids = self.pvp.encode(example)

    attention_mask = [1] * len(input_ids)
    length = len(input_ids)

    padding_length = self.wrapper.config.max_seq_length - len(input_ids)
    if padding_length < 0:
        raise ValueError(
            f"Maximum sequence length is too small, got {len(input_ids)} input ids")

    input_ids = input_ids + ([self.wrapper.tokenizer.pad_token_id] * padding_length)
    attention_mask = attention_mask + ([0] * padding_length)
    token_type_ids = token_type_ids + ([0] * padding_length)

    assert len(input_ids) == self.wrapper.config.max_seq_length
    assert len(attention_mask) == self.wrapper.config.max_seq_length
    assert len(token_type_ids) == self.wrapper.config.max_seq_length

    label = self.label_map[example.label] if example.label is not None else -100
    logits = example.logits if example.logits else [-1]

    if labelled:
        mlm_labels = self.pvp.get_mask_positions(input_ids)
        if self.wrapper.config.model_type == 'gpt2':
            # shift labels to the left by one
            mlm_labels.append(mlm_labels.pop(0))
    else:
        mlm_labels = [-1] * self.wrapper.config.max_seq_length

    return InputFeatures(input_ids=input_ids,
                         attention_mask=attention_mask,
                         token_type_ids=token_type_ids,
                         label=label,
                         mlm_labels=mlm_labels,
                         logits=logits,
                         idx=example.idx,
                         length=length)
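# Illustration (toy values, not from the repo) of the GPT-2 branch above: a causal LM's
# logits at position i predict the token at position i + 1, so the flag marking a mask
# position must be read one step earlier. Rotating the flag list left by one, as the
# append/pop idiom above does, achieves exactly that.
mlm_labels_example = [-1, -1, 1, -1]                 # mask token sits at index 2
mlm_labels_example.append(mlm_labels_example.pop(0))
assert mlm_labels_example == [-1, 1, -1, -1]         # flag now at the predicting position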