コード例 #1
0
    def convert_texts_to_ids(self, batch_text):
        """Convert a batch of plain-text instances into padded id tensors.

        When ``field_config.need_convert`` is set, each text is tokenized and
        mapped to ids. Otherwise it is split on single spaces. Sequences longer
        than ``field_config.max_seq_len`` are shortened with ``truncation_words``
        using ``field_config.truncation_type``.

        :param batch_text: iterable of text instances
        :return: list ``[padded_ids, mask_ids, batch_seq_lens]`` as returned by
            ``pad_batch_data``
        """
        config = self.field_config
        max_len = config.max_seq_len

        def _to_ids(text):
            # Either run the tokenizer or treat the text as already tokenized.
            if not config.need_convert:
                return text.split(" ")
            return self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(text))

        src_ids = []
        for text in batch_text:
            ids = _to_ids(text)
            # Truncation strategy: only applied to over-long sequences.
            if len(ids) > max_len:
                ids = truncation_words(ids, max_len, config.truncation_type)
            src_ids.append(ids)

        padded_ids, mask_ids, batch_seq_lens = pad_batch_data(
            src_ids,
            max_len=max_len,
            pad_idx=config.padding_id,
            return_input_mask=True,
            return_seq_lens=True,
            paddle_version_code=self.paddle_version_code)
        return [padded_ids, mask_ids, batch_seq_lens]
コード例 #2
0
    def convert_texts_to_ids(self, batch_text):
        """Convert a batch of dense-feature text instances into a float tensor.

        Each instance is a sequence of feature vectors separated by the literal
        ``' [SEP] '``. Each vector is a single-space-separated list of numbers.
        Instances with more than ``field_config.max_seq_len`` vectors are
        shortened with ``truncation_words``. Shorter ones are right-padded with
        all-zero vectors.

        Args:
            batch_text (list of str): the raw instances.

        Returns:
            list: a single-element list holding a float32 numpy array of shape
            ``(len(batch_text), field_config.max_seq_len, self._feature_dim)``.

        Raises:
            ValueError: if a value cannot be parsed as a float, or if a feature
                vector's length differs from ``self._feature_dim``.
        """
        max_len = self.field_config.max_seq_len
        feature_dim = self._feature_dim

        batch_fea_list = []
        for text in batch_text:
            fea_list = [[float(y) for y in x.split(' ')] for x in text.split(' [SEP] ')]

            # Truncation strategy for over-long instances.
            if len(fea_list) > max_len:
                logging.warning('input instance is to long: %s', text)
                fea_list = truncation_words(fea_list, max_len, self.field_config.truncation_type)

            # Validate widths up front. A wrong-length vector would otherwise
            # yield a ragged array or, worse, be silently re-sliced by a reshape.
            for vec in fea_list:
                if len(vec) != feature_dim:
                    raise ValueError('feature dim mismatch: expected %d, got %d in instance: %s'
                                     % (feature_dim, len(vec), text))
            batch_fea_list.append(fea_list)

        # Pre-allocate the zero-padded output and copy each instance into it.
        padded_ids = np.zeros((len(batch_fea_list), max_len, feature_dim), dtype='float32')
        for row, fea_list in enumerate(batch_fea_list):
            if fea_list:  # an empty slice assignment would not broadcast
                padded_ids[row, :len(fea_list)] = fea_list

        return [padded_ids]
コード例 #3
0
def convert_text_to_id(text, field_config):
    """Convert a single plain-text instance into a list of ids.

    :param text: the raw text; must be non-empty
    :param field_config: a ``Field`` instance carrying tokenization and
        truncation settings
    :return: the ids produced by ``field_config.tokenizer`` or, when
        ``need_convert`` is false, the single-space-separated tokens of
        ``text``; truncated to ``field_config.max_seq_len`` when longer
    :raises ValueError: if ``text`` is empty or None
    :raises TypeError: if ``field_config`` is not a ``Field``
    """
    if not text:
        raise ValueError("text input is None")
    if not isinstance(field_config, Field):
        raise TypeError("field_config input is must be Field class")

    if not field_config.need_convert:
        ids = text.split(" ")
    else:
        tok = field_config.tokenizer
        ids = tok.convert_tokens_to_ids(tok.tokenize(text))

    limit = field_config.max_seq_len
    # Truncation strategy: only shorten sequences that exceed the limit.
    return ids if len(ids) <= limit else truncation_words(ids, limit, field_config.truncation_type)
コード例 #4
0
    def convert_texts_to_ids(self, batch_text):
        """Convert a batch of ``[SEP]``-delimited name lists into id tensors.

        Each text is tokenized, or split on single spaces when
        ``field_config.need_convert`` is false. It is truncated to
        ``max_seq_len - 1`` tokens so that a trailing ``[SEP]`` always fits,
        and terminated with ``[SEP]`` if it does not already end with one.
        Every ``[SEP]`` closes one item (name block). For each block, its token
        positions and its length are recorded.

        Args:
            batch_text (list of str): the raw instances.

        Returns:
            list: ``[padded_ids, mask_ids, batch_seq_lens, name_len,
            batch_name_pos, batch_name_block_len]``, where:

            - the first three come from ``pad_batch_data`` over the token ids;
            - ``name_len`` is the number of blocks per instance (int64,
              reshaped to ``self.seq_len_shape``);
            - ``batch_name_pos`` holds each block's token positions,
              zero-padded to ``self.max_name_tokens``;
            - ``batch_name_block_len`` holds each block's length.

        Raises:
            ValueError: if an instance has more than ``self.max_item_len``
                blocks, or a block has more than ``self.max_name_tokens`` tokens.
            AssertionError: if an instance contains an empty block (e.g. two
                adjacent ``[SEP]`` ids, or an empty tokenization).
        """
        max_len = self.field_config.max_seq_len
        max_name_tokens = self.max_name_tokens

        src_ids = []
        name_len = []
        name_block_pos = []
        name_block_len = []
        # NOTE(review): method name spelled as called here (sic); verify against
        # the tokenizer before renaming.
        sep_id = self.tokenizer.covert_token_to_id("[SEP]")
        for text in batch_text:
            if self.field_config.need_convert:
                tokens = self.tokenizer.tokenize(text)
                src_id = self.tokenizer.convert_tokens_to_ids(tokens)
            else:
                # NOTE(review): these are str tokens, while sep_id is presumably
                # an int id, so only the appended sep_id below can match it.
                # Confirm this is intended.
                src_id = text.split(" ")

            # Truncation strategy: reserve one slot for the trailing [SEP].
            if len(src_id) > max_len - 1:
                logging.warning('input instance is to long(max %d): %s', max_len - 1, text)
                src_id = truncation_words(src_id, max_len - 1, self.field_config.truncation_type)
            # Check emptiness first: src_id[-1] would raise IndexError on [].
            if not src_id or src_id[-1] != sep_id:
                src_id.append(sep_id)

            item_num = src_id.count(sep_id)
            if item_num > self.max_item_len:
                raise ValueError("too many items. expacted max is %d, but got %d" % (self.max_item_len, item_num))
            src_ids.append(src_id)

            idx_begin = 0
            block_pos_tmp = []
            block_len_tmp = []
            for idx_end, tid in enumerate(src_id):
                if tid != sep_id:
                    continue
                block_len = idx_end - idx_begin
                # A longer block would produce a position row wider than
                # max_name_tokens and break the [-1, max_item_len,
                # max_name_tokens] layout below.
                if block_len > max_name_tokens:
                    raise ValueError("too many tokens in one item. expected max is %d, but got %d: %s"
                                     % (max_name_tokens, block_len, text))
                block_pos_tmp.append(list(range(idx_begin, idx_end)) + [0] * (max_name_tokens - block_len))
                block_len_tmp.append(block_len)
                idx_begin = idx_end + 1
            assert all(x > 0 for x in block_len_tmp), 'token len should > 0: %s' % text
            name_len.append(len(block_pos_tmp))
            name_block_pos.append(block_pos_tmp)
            name_block_len.append(block_len_tmp)

        padded_ids, mask_ids, batch_seq_lens = pad_batch_data(
            src_ids,
            max_len=max_len,
            pad_idx=self.field_config.padding_id,
            return_input_mask=True,
            return_seq_lens=True,
            paddle_version_code=self.paddle_version_code)

        name_len = np.array(name_len).astype('int64').reshape(self.seq_len_shape)
        batch_name_pos = pad_batch_data(name_block_pos,
                                        shape=[-1, self.max_item_len, max_name_tokens],
                                        pad_idx=[0] * max_name_tokens)
        batch_name_block_len = pad_batch_data(name_block_len, shape=[-1, self.max_item_len])

        return [padded_ids, mask_ids, batch_seq_lens,
                name_len, batch_name_pos, batch_name_block_len]