Example #1
    def convert_texts_to_ids(self, batch_text):
        """将一个batch的明文text转成id
        :param batch_text:
        :return:
        """
        src_ids = []
        for text in batch_text:
            if self.field_config.need_convert:
                tokens = self.tokenizer.tokenize(text)
                src_id = self.tokenizer.convert_tokens_to_ids(tokens)
            else:
                # inputs are pre-tokenized id strings such as "7 8 9"
                src_id = [int(i) for i in text.split(" ")]

            # Apply the truncation strategy
            if len(src_id) > self.field_config.max_seq_len:
                src_id = truncation_words(src_id, self.field_config.max_seq_len, self.field_config.truncation_type)
            src_ids.append(src_id)

        return_list = []
        padded_ids, batch_seq_lens = pad_batch_data(src_ids,
                                                    pad_idx=self.field_config.padding_id,
                                                    return_input_mask=False,
                                                    return_seq_lens=True)
        return_list.append(padded_ids)
        return_list.append(batch_seq_lens)

        return return_list
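
A note on the truncation strategy: every example calls truncation_words to enforce max_seq_len. A minimal sketch of a compatible helper follows; the meaning of the two truncation_type values (0 keeps the head of the sequence, 1 keeps the tail) is an assumption about this codebase, not something the examples state.

    def truncation_words(words, max_seq_length, truncation_type):
        """Truncate words to at most max_seq_length items.

        A sketch under assumed semantics: truncation_type 0 drops the
        tail of the sequence, truncation_type 1 drops the head.
        """
        if len(words) <= max_seq_length:
            return words
        if truncation_type == 0:
            return words[:max_seq_length]   # keep the front
        if truncation_type == 1:
            return words[-max_seq_length:]  # keep the back
        raise ValueError("unsupported truncation_type: %d" % truncation_type)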
Example #2
    def convert_texts_to_ids(self, batch_text):
        """将一个batch的明文text转成id
        :param batch_text:
        :return:
        """
        src_ids = []
        for text in batch_text:
            if self.field_config.need_convert:
                tokens = self.tokenizer.tokenize(text)
                src_id = self.tokenizer.convert_tokens_to_ids(tokens)
            else:
                if isinstance(text, str):
                    # split first so the int conversion below runs over
                    # tokens, not over individual characters
                    text = text.split(" ")
                src_id = [int(i) for i in text]

            # Apply the truncation strategy (reserve two slots for the
            # boundary tokens added below)
            if len(src_id) > self.field_config.max_seq_len - 2:
                src_id = truncation_words(src_id, self.field_config.max_seq_len - 2, self.field_config.truncation_type)
            # wrap the sequence with the unk-token id as boundary markers
            unk_id = self.tokenizer.vocabulary.vocab_dict[self.field_config.tokenizer_info["unk_token"]]
            src_id.insert(0, unk_id)
            src_id.append(unk_id)
            src_ids.append(src_id)
        
        return_list = []
        padded_ids, mask_ids, batch_seq_lens = pad_batch_data(src_ids,
                                                              pad_idx=self.field_config.padding_id,
                                                              return_input_mask=True,
                                                              return_seq_lens=True)
        return_list.append(padded_ids)
        return_list.append(mask_ids)
        return_list.append(batch_seq_lens)

        return return_list
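
To make the unk-token wrapping above concrete, here is a worked illustration; the ids are invented for the example, not taken from any real vocab.

    # Assume the unk token maps to id 17 and the input line is "7 8 9".
    # After conversion and wrapping, the per-example sequence is
    #   src_id = [17, 7, 8, 9, 17]
    # pad_batch_data then pads every sequence in the batch to the
    # batch-wide maximum length using padding_id.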
Example #3
    def convert_texts_to_ids(self, batch_text):
        """ 明文序列化
        :return: id_list
        """
        src_ids = []
        for text in batch_text:
            if self.tokenizer and self.field_config.need_convert:
                tokens = self.tokenizer.tokenize(text)
                src_id = self.tokenizer.convert_tokens_to_ids(tokens)
            else:
                # pre-tokenized numeric strings; pad_batch_data casts
                # them via the int/float data_type chosen below
                src_id = text.split(" ")

            # Apply the truncation strategy
            if len(src_id) > self.field_config.max_seq_len:
                src_id = truncation_words(src_id,
                                          self.field_config.max_seq_len,
                                          self.field_config.truncation_type)
            src_ids.append(src_id)

        data_type = "int64" if self.field_config.data_type == DataShape.INT else "float32"

        padded_ids, batch_seq_lens = pad_batch_data(
            src_ids,
            insts_data_type=data_type,
            pad_idx=self.field_config.padding_id,
            return_input_mask=False,
            return_seq_lens=True)
        return_list = []
        return_list.append(padded_ids)
        return_list.append(batch_seq_lens)
        return return_list
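
All six examples delegate batch padding to pad_batch_data. A minimal sketch of a compatible helper is shown below; it mirrors the parameters used at the call sites (insts_data_type, pad_idx, return_input_mask, return_seq_lens), but the production implementation may reshape its outputs differently, so treat the details as assumptions.

    import numpy as np

    def pad_batch_data(insts, insts_data_type="int64", pad_idx=0,
                       return_input_mask=False, return_seq_lens=False):
        """Pad every instance in insts to the batch max length.

        A sketch matching the call sites above, not the production code.
        """
        max_len = max(len(inst) for inst in insts)
        padded = np.array(
            [list(inst) + [pad_idx] * (max_len - len(inst)) for inst in insts],
            dtype=insts_data_type)
        result = [padded]
        if return_input_mask:
            # 1.0 marks real tokens, 0.0 marks padding positions
            mask = np.array(
                [[1.0] * len(inst) + [0.0] * (max_len - len(inst))
                 for inst in insts],
                dtype="float32")
            result.append(mask)
        if return_seq_lens:
            result.append(np.array([len(inst) for inst in insts],
                                   dtype="int64"))
        return result if len(result) > 1 else result[0]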
Example #4
    def serialize_batch_records(self, batch_records):
        """pad batch records"""
        batch_token_ids = [record.token_ids for record in batch_records]
        batch_text_type_ids = [
            record.text_type_ids for record in batch_records
        ]
        batch_position_ids = [record.position_ids for record in batch_records]
        batch_task_ids = [record.task_ids for record in batch_records]
        if "predict" not in self.name:
            batch_labels = [record.label_id for record in batch_records]
            if self.is_classify:
                batch_labels = np.array(batch_labels).astype("int64").reshape(
                    [-1, 1])
            elif self.is_regression:
                batch_labels = np.array(batch_labels).astype(
                    "float32").reshape([-1, 1])
        else:
            if self.is_classify:
                batch_labels = np.array([]).astype("int64").reshape([-1, 1])
            elif self.is_regression:
                batch_labels = np.array([]).astype("float32").reshape([-1, 1])

        # qid may legitimately be 0, so compare against None explicitly
        if batch_records[0].qid is not None:
            batch_qids = [record.qid for record in batch_records]
            batch_qids = np.array(batch_qids).astype("int64").reshape([-1, 1])
        else:
            batch_qids = np.array([]).astype("int64").reshape([-1, 1])

        # padding
        padded_token_ids, input_mask = pad_batch_data(batch_token_ids,
                                                      pad_idx=self.pad_id,
                                                      return_input_mask=True)
        padded_text_type_ids = pad_batch_data(batch_text_type_ids,
                                              pad_idx=self.pad_id)
        padded_position_ids = pad_batch_data(batch_position_ids,
                                             pad_idx=self.pad_id)
        padded_task_ids = pad_batch_data(batch_task_ids, pad_idx=0)

        return_list = [
            padded_token_ids, padded_text_type_ids, padded_position_ids,
            padded_task_ids, input_mask, batch_labels, batch_qids
        ]

        return return_list
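
serialize_batch_records reads a fixed set of record attributes. The stand-in below covers exactly those fields so the method can be exercised in isolation; it is an assumption, and the project's real record type presumably carries more state.

    from collections import namedtuple

    # Minimal stand-in for the record objects consumed above.
    Record = namedtuple("Record", [
        "token_ids", "text_type_ids", "position_ids", "task_ids",
        "label_id", "qid"
    ])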
Example #5
    def convert_texts_to_ids(self, batch_text):
        """将一个batch的明文text转成id
        :param batch_text:
        :return:
        """
        src_ids = []
        position_ids = []
        task_ids = []
        sentence_ids = []
        for text in batch_text:
            if self.field_config.need_convert:
                tokens_text = self.tokenizer.tokenize(text)
                # Apply the truncation strategy (reserve two slots
                # for [CLS] and [SEP])
                if len(tokens_text) > self.field_config.max_seq_len - 2:
                    tokens_text = truncation_words(
                        tokens_text, self.field_config.max_seq_len - 2,
                        self.field_config.truncation_type)
                tokens = ["[CLS]"] + tokens_text + ["[SEP]"]
                src_id = self.tokenizer.convert_tokens_to_ids(tokens)
            else:
                if isinstance(text, str):
                    # split first so the int conversion below runs over
                    # tokens, not over individual characters
                    text = text.split(" ")
                src_id = [int(i) for i in text]
                if len(src_id) > self.field_config.max_seq_len - 2:
                    src_id = truncation_words(
                        src_id, self.field_config.max_seq_len - 2,
                        self.field_config.truncation_type)
                src_id.insert(0, self.tokenizer.covert_token_to_id("[CLS]"))
                src_id.append(self.tokenizer.covert_token_to_id("[SEP]"))

            src_ids.append(src_id)
            pos_id = list(range(len(src_id)))
            task_id = [0] * len(src_id)
            sentence_id = [0] * len(src_id)
            position_ids.append(pos_id)
            task_ids.append(task_id)
            sentence_ids.append(sentence_id)

        return_list = []
        padded_ids, input_mask, batch_seq_lens = pad_batch_data(
            src_ids,
            pad_idx=self.field_config.padding_id,
            return_input_mask=True,
            return_seq_lens=True)
        sent_ids_batch = pad_batch_data(sentence_ids,
                                        pad_idx=self.field_config.padding_id)
        pos_ids_batch = pad_batch_data(position_ids,
                                       pad_idx=self.field_config.padding_id)
        task_ids_batch = pad_batch_data(task_ids,
                                        pad_idx=self.field_config.padding_id)

        return_list.append(padded_ids)  # append src_ids
        return_list.append(sent_ids_batch)  # append sent_ids
        return_list.append(pos_ids_batch)  # append pos_ids
        return_list.append(input_mask)  # append mask
        return_list.append(task_ids_batch)  # append task_ids
        return_list.append(batch_seq_lens)  # append seq_lens

        return return_list
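
For a single sentence the auxiliary sequences line up with src_id as in this illustration; the token ids are invented for the example.

    # Assume "[CLS]" maps to id 1, "[SEP]" to id 2, and the two content
    # tokens to ids 100 and 200. For that example:
    #   src_id      = [1, 100, 200, 2]
    #   pos_id      = [0, 1, 2, 3]    # position within the sequence
    #   task_id     = [0, 0, 0, 0]    # single-task setup
    #   sentence_id = [0, 0, 0, 0]    # single segment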
Example #6
    def convert_texts_to_ids(self, batch_text):
        """将一个batch的明文text转成id
        :param batch_text:
        :return:
        """
        src_ids = []
        position_ids = []
        task_ids = []
        sentence_ids = []
        batch_text_a, batch_text_b = batch_text
        assert len(batch_text_a) == len(batch_text_b)

        for text_a, text_b in zip(batch_text_a, batch_text_b):
            if self.field_config.need_convert:
                tokens_text_a = self.tokenizer.tokenize(text_a)
                tokens_text_b = self.tokenizer.tokenize(text_b)
                # Apply the truncation strategy (reserve three slots
                # for [CLS] and the two [SEP]s)
                truncate_seq_pair(tokens_text_a, tokens_text_b,
                                  self.field_config.max_seq_len - 3)
                text_a_len, text_b_len = len(tokens_text_a), len(tokens_text_b)
                tokens_text = tokens_text_a + ["[SEP]"] + tokens_text_b
                tokens = ["[CLS]"] + tokens_text + ["[SEP]"]
                src_id = self.tokenizer.convert_tokens_to_ids(tokens)
            else:
                src_a_id = [int(i) for i in text_a.split(" ")]
                src_b_id = [int(i) for i in text_b.split(" ")]
                truncate_seq_pair(src_a_id, src_b_id,
                                  self.field_config.max_seq_len - 3)
                text_a_len, text_b_len = len(src_a_id), len(src_b_id)
                cls_id = self.tokenizer.covert_token_to_id("[CLS]")
                sep_id = self.tokenizer.covert_token_to_id("[SEP]")
                # truncate_seq_pair already fits the pair into
                # max_seq_len - 3 slots, so the special tokens can be
                # added unconditionally
                src_id = [cls_id] + src_a_id + [sep_id] + src_b_id + [sep_id]

            src_ids.append(src_id)
            pos_id = list(range(len(src_id)))
            task_id = [0] * len(src_id)
            sentence_id = [0] * (text_a_len + 2) + [1] * (text_b_len + 1)
            position_ids.append(pos_id)
            task_ids.append(task_id)
            sentence_ids.append(sentence_id)

        return_list = []

        padded_ids, input_mask, batch_seq_lens = pad_batch_data(
            src_ids,
            pad_idx=self.field_config.padding_id,
            return_input_mask=True,
            return_seq_lens=True)
        sent_ids_batch = pad_batch_data(sentence_ids,
                                        pad_idx=self.field_config.padding_id)
        pos_ids_batch = pad_batch_data(position_ids,
                                       pad_idx=self.field_config.padding_id)
        task_ids_batch = pad_batch_data(task_ids,
                                        pad_idx=self.field_config.padding_id)

        return_list.append(padded_ids)  # append src_ids
        return_list.append(sent_ids_batch)  # append sent_ids
        return_list.append(pos_ids_batch)  # append pos_ids
        return_list.append(input_mask)  # append mask
        return_list.append(task_ids_batch)  # append task_ids
        return_list.append(batch_seq_lens)  # append seq_lens

        return return_list
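
Example #6 additionally depends on truncate_seq_pair. The classic BERT-style pair-truncation heuristic fits the signature used above; a sketch follows, with the caveat that the project's own helper may differ in detail.

    def truncate_seq_pair(tokens_a, tokens_b, max_length):
        """Trim a pair of sequences in place until they fit together.

        BERT-style heuristic: always shorten the currently longer
        sequence so both sides keep proportionate context.
        """
        while len(tokens_a) + len(tokens_b) > max_length:
            if len(tokens_a) > len(tokens_b):
                tokens_a.pop()
            else:
                tokens_b.pop()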