示例#1
0
    def add_special_input_features(self, input_example: InputExample,
                                   input_features: InputFeatures) -> None:
        """Attach per-candidate MLM targets and 0/1 labels for one question.

        Reads ``input_example.meta['candidates']``, ``['answers']`` and
        ``['question_idx']`` and writes into ``input_features.meta``:

        * ``'candidate_token_ids'``: one list per candidate, as long as
          ``input_ids``, holding ``-100`` everywhere except the positions
          starting at the first mask token, which hold the candidate's token ids;
        * ``'candidate_labels'``: ``1`` if the candidate text is among the
          answers, else ``0``;
        * ``'question_idx'``: copied from the example.

        The candidate texts are also recorded, in order, in
        ``self.original_choices[question_idx]`` (replacing any previous entry).

        Raises:
            ValueError: if ``input_ids`` contains no mask token (raised by
                ``list.index``), or if a candidate's token ids would run past
                the end of ``input_ids``.
        """
        input_ids = input_features.input_ids
        mask_start = input_ids.index(self.wrapper.tokenizer.mask_token_id)

        choices = input_example.meta['candidates']
        answers = input_example.meta['answers']
        question_idx = input_example.meta['question_idx']

        input_features.meta['candidate_token_ids'] = []
        input_features.meta['candidate_labels'] = []
        input_features.meta['question_idx'] = question_idx

        self.original_choices[question_idx] = []

        for choice_text in choices:
            choice_token_ids = get_verbalization_ids(choice_text,
                                                     self.wrapper.tokenizer,
                                                     force_single_token=False)
            choice_label = 1 if choice_text in answers else 0

            mask_end = mask_start + len(choice_token_ids)
            # Slice assignment past the end of the list would silently grow it
            # and desynchronize it from input_ids, so reject that explicitly.
            if mask_end > len(input_ids):
                raise ValueError(
                    f"Candidate {choice_text!r} needs {len(choice_token_ids)} "
                    f"tokens but only {len(input_ids) - mask_start} positions "
                    f"follow the first mask token")
            candidate_token_ids = [-100] * len(input_ids)
            candidate_token_ids[mask_start:mask_end] = choice_token_ids

            input_features.meta['candidate_token_ids'].append(
                candidate_token_ids)
            input_features.meta['candidate_labels'].append(choice_label)
            self.original_choices[question_idx].append(choice_text)
示例#2
0
    def add_special_input_features(self, input_example: InputExample, input_features: InputFeatures) -> None:
        """Store token-id targets for the two alternatives ``choice1``/``choice2``.

        Does nothing for a ``'sequence_classifier'`` wrapper. Otherwise, for each
        of ``'choice1'`` and ``'choice2'``, writes ``input_features.meta['<key>_token_ids']``:
        a list as long as ``input_ids`` filled with ``-100``, except that the
        positions starting at the first mask token hold that choice's token ids.
        """
        if self.wrapper.config.wrapper_type != 'sequence_classifier':
            tokenizer = self.wrapper.tokenizer
            first_mask = input_features.input_ids.index(tokenizer.mask_token_id)
            seq_len = len(input_features.input_ids)

            for key in ('choice1', 'choice2'):
                token_ids = get_verbalization_ids(input_example.meta[key], tokenizer, force_single_token=False)
                targets = [-100] * seq_len
                targets[first_mask:first_mask + len(token_ids)] = token_ids
                input_features.meta[f'{key}_token_ids'] = targets
示例#3
0
    def add_special_input_features(self, input_example: InputExample,
                                   input_features: InputFeatures) -> None:
        """Attach one MLM target sequence per answer choice.

        Does nothing for a ``'sequence_classifier'`` wrapper. Otherwise the
        choices are taken from ``input_example.meta['choices']`` when present,
        else from the first verbalization of every label in
        ``self.wrapper.config.label_list``. For each choice, a list as long as
        ``input_ids`` is appended to ``input_features.meta['choice_token_ids']``.
        The list is ``-100`` everywhere except the positions starting at the
        first mask token, which hold the choice's token ids.

        Raises:
            ValueError: if ``input_ids`` contains no mask token (raised by
                ``list.index``), or if a choice's token ids would run past the
                end of ``input_ids``.
        """
        if self.wrapper.config.wrapper_type == 'sequence_classifier':
            return

        input_ids = input_features.input_ids
        mask_start = input_ids.index(self.wrapper.tokenizer.mask_token_id)

        if 'choices' in input_example.meta:
            choices = list(input_example.meta['choices'])
        else:
            choices = [
                self.wrapper.preprocessor.pvp.verbalize(label)[0]
                for label in self.wrapper.config.label_list
            ]

        input_features.meta['choice_token_ids'] = []

        for choice_text in choices:
            choice_token_ids = get_verbalization_ids(choice_text,
                                                     self.wrapper.tokenizer,
                                                     force_single_token=False)
            mask_end = mask_start + len(choice_token_ids)
            # Slice assignment past the end of the list would silently grow it
            # and desynchronize it from input_ids, so reject that explicitly.
            if mask_end > len(input_ids):
                raise ValueError(
                    f"Choice {choice_text!r} needs {len(choice_token_ids)} "
                    f"tokens but only {len(input_ids) - mask_start} positions "
                    f"follow the first mask token")
            candidate_token_ids = [-100] * len(input_ids)
            candidate_token_ids[mask_start:mask_end] = choice_token_ids
            input_features.meta['choice_token_ids'].append(candidate_token_ids)
示例#4
0
    def add_special_input_features(self, input_example: InputExample, input_features: InputFeatures) -> None:
        """Attach the target span's token ids aligned to the mask positions.

        Does nothing for a ``'sequence_classifier'`` wrapper. Otherwise it stores
        two entries in ``input_features.meta``. ``'target'`` is the raw
        ``input_example.meta['span1_text']``. ``'target_token_ids'`` is a list as
        long as ``input_ids``, ``-100`` everywhere except the mask span. The mask
        span holds the target's token ids, right-padded with the tokenizer's pad
        id so that every mask position has a prediction target.

        Raises:
            ValueError: if ``input_ids`` contains no mask token (raised by
                ``list.index``), or if the target needs more tokens than there
                are mask tokens.
        """
        if self.wrapper.config.wrapper_type == 'sequence_classifier':
            return

        mask_id = self.wrapper.tokenizer.mask_token_id
        mask_start = input_features.input_ids.index(mask_id)
        # Assumes the mask tokens form one contiguous run starting at mask_start.
        num_masks = input_features.input_ids.count(mask_id)
        mask_end = mask_start + num_masks

        target = input_example.meta['span1_text']
        input_features.meta['target'] = target
        target_token_ids = get_verbalization_ids(target, self.wrapper.tokenizer, force_single_token=False)
        if len(target_token_ids) > num_masks:
            # Otherwise the slice assignment below would grow the list past
            # len(input_ids) instead of filling the mask span.
            raise ValueError(
                f"Target {target!r} needs {len(target_token_ids)} tokens but "
                f"only {num_masks} mask tokens are available")
        input_features.meta['target_token_ids'] = [-100] * len(input_features.input_ids)

        # we also predict <pad> tokens at the missing positions; build a new list
        # rather than mutating the one returned by get_verbalization_ids
        padded_target_ids = target_token_ids + (
            [self.wrapper.tokenizer.pad_token_id] * (num_masks - len(target_token_ids)))
        input_features.meta['target_token_ids'][mask_start:mask_end] = padded_target_ids
示例#5
0
    def get_input_features(self,
                           example: InputExample,
                           labelled: bool,
                           priming: bool = False,
                           **kwargs) -> InputFeatures:
        """Encode ``example`` with the PVP into padded, task-aware features.

        The example is rendered by ``self.pvp.encode`` (pattern sentence plus
        label mapping) and right-padded to ``max_seq_length``. Its label and
        task are then mapped to integer ids.

        Args:
            example: the example to encode. ``example.guid`` must contain a
                ``'-'`` followed by an integer (its second ``'-'``-separated
                field is parsed with ``int``); ``example.task`` must be a key of
                ``task_to_id``.
            labelled: if True, ``mlm_labels`` comes from
                ``self.pvp.get_mask_positions``; otherwise it is all ``-1``.
            priming: accepted for interface compatibility; not used here.

        Returns:
            The populated ``InputFeatures``.

        Raises:
            ValueError: if the encoded sequence is longer than ``max_seq_length``.
            KeyError: if the label (after int/str normalization) is not in
                ``self.label_map``, or the task is not in ``task_to_id``.
        """
        # Encode through the PVP (pattern sentence + label mapping).
        input_ids, token_type_ids, block_flag = self.pvp.encode(example)
        max_seq_length = self.wrapper.config.max_seq_length

        attention_mask = [1] * len(input_ids)
        padding_length = max_seq_length - len(input_ids)

        if padding_length < 0:
            raise ValueError(
                f"Maximum sequence length is too small, got {len(input_ids)} input ids"
            )

        # Token-id sequence followed by padding.
        input_ids = input_ids + ([self.wrapper.tokenizer.pad_token_id] *
                                 padding_length)
        attention_mask = attention_mask + ([0] * padding_length)
        token_type_ids = token_type_ids + ([0] * padding_length)
        block_flag = block_flag + ([0] * padding_length)

        assert len(input_ids) == max_seq_length
        assert len(attention_mask) == max_seq_length
        assert len(token_type_ids) == max_seq_length
        assert len(block_flag) == max_seq_length

        example_label = example.label
        # Numeric labels may be stored as int (0, 1) while the label map uses
        # strings ('0', '1'), or vice versa, so try the other representation.
        if example_label is not None and example_label not in self.label_map:
            if isinstance(example_label, int):
                example_label = str(example_label)
            elif isinstance(example_label, str):
                try:
                    example_label = int(example_label)
                except ValueError:
                    # Non-numeric string: keep it unchanged so the lookup below
                    # raises a KeyError naming the unknown label rather than a
                    # confusing int() parse error.
                    pass

        label = self.label_map[
            example_label] if example.label is not None else -100
        # Index of the example's task within its task group.
        task = task_to_id[example.task]
        logits = example.logits if example.logits else [-1]

        if labelled:
            mlm_labels = self.pvp.get_mask_positions(input_ids)
        else:
            mlm_labels = [-1] * max_seq_length

        return InputFeatures(
            guid=int(example.guid.split('-')[1]),
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            task=task,
            label=label,  # label id, e.g. 0, 1, ...
            mlm_labels=mlm_labels,  # e.g. [-1, -1, ..., -1, 1, -1, -1, ...]
            logits=logits,
            idx=example.idx,
            block_flag=block_flag)
示例#6
0
    def get_input_features(self, example: InputExample,
                           **kwargs) -> InputFeatures:
        """Encode ``example`` as plain sequence-classification features.

        Inputs come from ``self.wrapper.task_helper.get_sequence_classifier_inputs``
        when a task helper is configured and returns something. Otherwise they
        come from ``tokenizer.encode_plus`` on ``text_a``/``text_b``, truncated to
        ``max_seq_length``. The result is right-padded to ``max_seq_length``;
        ``mlm_labels`` is all ``-1`` because no MLM targets are used here.

        Returns:
            The populated ``InputFeatures``.

        Raises:
            ValueError: if the inputs are longer than ``max_seq_length`` (e.g.
                a task helper returned untruncated ids).
            KeyError: if ``example.label`` is not in ``self.label_map``.
        """
        max_seq_length = self.wrapper.config.max_seq_length
        inputs = self.wrapper.task_helper.get_sequence_classifier_inputs(
            example) if self.wrapper.task_helper else None
        if inputs is None:
            inputs = self.wrapper.tokenizer.encode_plus(
                example.text_a if example.text_a else None,
                example.text_b if example.text_b else None,
                add_special_tokens=True,
                max_length=max_seq_length,
                truncation=True,
            )
        input_ids, token_type_ids = inputs["input_ids"], inputs.get(
            "token_type_ids")

        attention_mask = [1] * len(input_ids)
        padding_length = max_seq_length - len(input_ids)

        # Fail loudly like the MLM preprocessors do: a negative padding length
        # would otherwise only be caught by the asserts below, which are
        # stripped under `python -O`.
        if padding_length < 0:
            raise ValueError(
                f"Maximum sequence length is too small, got {len(input_ids)} input ids"
            )

        input_ids = input_ids + ([self.wrapper.tokenizer.pad_token_id] *
                                 padding_length)
        attention_mask = attention_mask + ([0] * padding_length)
        # Some tokenizers/task helpers provide no token type ids (None or empty).
        if not token_type_ids:
            token_type_ids = [0] * max_seq_length
        else:
            token_type_ids = token_type_ids + ([0] * padding_length)
        mlm_labels = [-1] * len(input_ids)

        assert len(input_ids) == max_seq_length
        assert len(attention_mask) == max_seq_length
        assert len(token_type_ids) == max_seq_length

        label = self.label_map[
            example.label] if example.label is not None else -100
        logits = example.logits if example.logits else [-1]

        return InputFeatures(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            label=label,
            mlm_labels=mlm_labels,
            logits=logits,
            idx=example.idx,
        )
示例#7
0
    def get_input_features(self,
                           example: InputExample,
                           labelled: bool,
                           priming: bool = False,
                           **kwargs) -> InputFeatures:
        """Encode ``example`` through the PVP and pad it to ``max_seq_length``.

        ``input_ids``, ``token_type_ids`` and ``block_flag`` come from
        ``self.pvp.encode``. They are right-padded (with the tokenizer's pad id
        for the ids, ``0`` otherwise), and an attention mask marks the real
        tokens. ``mlm_labels`` comes from ``self.pvp.get_mask_positions`` when
        ``labelled`` is true and is all ``-1`` otherwise. ``priming`` is unused.

        Raises:
            ValueError: if the encoded sequence is longer than ``max_seq_length``.
        """
        ids, type_ids, flags = self.pvp.encode(example)

        max_len = self.wrapper.config.max_seq_length
        n_real = len(ids)
        n_pad = max_len - n_real
        if n_pad < 0:
            raise ValueError(
                f"Maximum sequence length is too small, got {n_real} input ids"
            )

        padded_ids = ids + [self.wrapper.tokenizer.pad_token_id] * n_pad
        mask = [1] * n_real + [0] * n_pad
        type_ids = type_ids + [0] * n_pad
        flags = flags + [0] * n_pad

        for seq in (padded_ids, mask, type_ids, flags):
            assert len(seq) == max_len

        if example.label is None:
            label = -100
        else:
            label = self.label_map[example.label]
        logits = example.logits or [-1]

        mlm_labels = (self.pvp.get_mask_positions(padded_ids) if labelled
                      else [-1] * max_len)

        return InputFeatures(input_ids=padded_ids,
                             attention_mask=mask,
                             token_type_ids=type_ids,
                             label=label,
                             mlm_labels=mlm_labels,
                             logits=logits,
                             idx=example.idx,
                             block_flag=flags)
示例#8
0
 def add_special_input_features(self, input_example: InputExample,
                                input_features: InputFeatures) -> None:
     """Copy the example's ``question_idx`` into the features' metadata.

     Reads ``input_example.meta['question_idx']`` and stores it unchanged as
     ``input_features.meta['question_idx']``.
     """
     question_idx = input_example.meta['question_idx']
     input_features.meta['question_idx'] = question_idx
示例#9
0
    def get_input_features(self,
                           example: InputExample,
                           labelled: bool,
                           priming: bool = False,
                           **kwargs) -> InputFeatures:
        """Encode ``example`` (optionally with priming demonstrations) and pad it.

        Without priming, ids and token types come straight from
        ``self.pvp.encode(example)``. With priming, the labeled encodings of every
        example in ``example.meta['priming_data']`` are concatenated in front of
        the query's encoding. Token types and special tokens are then added by
        the tokenizer. The sequence is right-padded to ``max_seq_length``, and
        its unpadded length is reported as ``length``.

        When ``labelled`` is true, ``mlm_labels`` comes from
        ``self.pvp.get_mask_positions``. For ``model_type == 'gpt2'`` it is
        rotated left by one in place. Otherwise ``mlm_labels`` is all ``-1``.

        Raises:
            ValueError: if the encoded sequence is longer than ``max_seq_length``.
        """
        tokenizer = self.wrapper.tokenizer

        if not priming:
            input_ids, token_type_ids = self.pvp.encode(example)
        else:
            query_ids, _ = self.pvp.encode(example, priming=True)
            # presumably a list of InputExample demonstrations — see caller
            demonstrations = example.meta['priming_data']

            context_ids = []
            for demo in demonstrations:
                demo_ids, _ = self.pvp.encode(demo,
                                              priming=True,
                                              labeled=True)
                context_ids.extend(demo_ids)

            merged_ids = context_ids + query_ids
            token_type_ids = tokenizer.create_token_type_ids_from_sequences(
                merged_ids)
            input_ids = tokenizer.build_inputs_with_special_tokens(merged_ids)

        length = len(input_ids)
        max_len = self.wrapper.config.max_seq_length
        pad_count = max_len - length

        if pad_count < 0:
            raise ValueError(
                f"Maximum sequence length is too small, got {length} input ids"
            )

        attention_mask = [1] * length + [0] * pad_count
        input_ids = input_ids + [tokenizer.pad_token_id] * pad_count
        token_type_ids = token_type_ids + [0] * pad_count

        for seq in (input_ids, attention_mask, token_type_ids):
            assert len(seq) == max_len

        label = -100 if example.label is None else self.label_map[example.label]
        logits = example.logits or [-1]

        if not labelled:
            mlm_labels = [-1] * max_len
        else:
            mlm_labels = self.pvp.get_mask_positions(input_ids)
            if self.wrapper.config.model_type == 'gpt2':
                # shift labels to the left by one (rotate the list in place)
                mlm_labels.append(mlm_labels.pop(0))

        return InputFeatures(input_ids=input_ids,
                             attention_mask=attention_mask,
                             token_type_ids=token_type_ids,
                             label=label,
                             mlm_labels=mlm_labels,
                             logits=logits,
                             idx=example.idx,
                             length=length)