def convert_texts_to_ids(self, batch_text):
    """Convert a batch of plain-text samples into padded id sequences.

    Each text is tokenized and mapped to ids when ``need_convert`` is set,
    otherwise split on single spaces; sequences longer than ``max_seq_len``
    are truncated, then the batch is padded to a dense form.

    :param batch_text: list of input strings
    :return: list of [padded_ids, mask_ids, batch_seq_lens]
    """
    max_len = self.field_config.max_seq_len
    src_ids = []
    for text in batch_text:
        if self.field_config.need_convert:
            tokens = self.tokenizer.tokenize(text)
            src_id = self.tokenizer.convert_tokens_to_ids(tokens)
        else:
            src_id = text.split(" ")
        # Truncation policy: drop tokens beyond the configured max length.
        if len(src_id) > max_len:
            src_id = truncation_words(src_id, max_len,
                                      self.field_config.truncation_type)
        src_ids.append(src_id)

    padded_ids, mask_ids, batch_seq_lens = pad_batch_data(
        src_ids,
        max_len=max_len,
        pad_idx=self.field_config.padding_id,
        return_input_mask=True,
        return_seq_lens=True,
        paddle_version_code=self.paddle_version_code)
    return [padded_ids, mask_ids, batch_seq_lens]
def convert_texts_to_ids(self, batch_text):
    """Convert a batch of dense-feature text instances to a padded float array.

    Each instance is a string of items separated by " [SEP] ", where every
    item is a space-separated list of floats of width ``self._feature_dim``.
    Instances longer than ``max_seq_len`` items are truncated; shorter ones
    are padded with zero rows.

    :param batch_text: list of input strings
    :return: list containing one float32 ndarray shaped
        [batch, max_seq_len, feature_dim]
    """
    max_len = self.field_config.max_seq_len
    batch_fea_list = []
    for text in batch_text:
        fea_list = [[float(v) for v in item.split(' ')]
                    for item in text.split(' [SEP] ')]
        # Truncation policy: keep at most max_seq_len feature rows.
        if len(fea_list) > max_len:
            logging.warning('input instance is too long: %s', text)
            fea_list = truncation_words(fea_list, max_len,
                                        self.field_config.truncation_type)
        batch_fea_list.append(fea_list)

    # Pad every instance with all-zero feature rows up to max_len.
    pad_row = [0] * self._feature_dim
    padded_ids = np.array([inst + [pad_row] * (max_len - len(inst))
                           for inst in batch_fea_list])
    padded_ids = padded_ids.astype('float32').reshape(
        [-1, max_len, self._feature_dim])
    return [padded_ids]
def convert_text_to_id(text, field_config):
    """Convert one plain-text sample into a list of ids.

    :param text: input text; must be non-empty
    :param field_config: a ``Field`` instance describing tokenization,
        truncation type and the max sequence length
    :return: list of token ids (or raw space-split tokens when
        ``need_convert`` is False)
    :raises ValueError: if ``text`` is empty or None
    :raises TypeError: if ``field_config`` is not a ``Field``
    """
    if not text:
        raise ValueError("text input is None")
    if not isinstance(field_config, Field):
        raise TypeError("field_config must be an instance of Field")

    if field_config.need_convert:
        tokenizer = field_config.tokenizer
        tokens = tokenizer.tokenize(text)
        ids = tokenizer.convert_tokens_to_ids(tokens)
    else:
        ids = text.split(" ")

    # Truncation policy: drop tokens beyond the configured max length.
    if len(ids) > field_config.max_seq_len:
        ids = truncation_words(ids, field_config.max_seq_len,
                               field_config.truncation_type)
    return ids
def convert_texts_to_ids(self, batch_text):
    """Convert a batch of text instances into padded id tensors plus
    per-item name-block bookkeeping.

    Items inside one instance are delimited by "[SEP]" tokens; a trailing
    "[SEP]" is appended when missing. For every item this records the token
    positions (zero-padded up to ``max_name_tokens``) and the token count.

    :param batch_text: list of input strings
    :return: [padded_ids, mask_ids, batch_seq_lens, name_len,
              batch_name_pos, batch_name_block_len]
    :raises ValueError: if an instance holds more than ``max_item_len``
        "[SEP]"-terminated items
    """
    max_len = self.field_config.max_seq_len
    src_ids = []
    name_len = []
    name_block_pos = []
    name_block_len = []
    # NOTE: `covert_token_to_id` (sic) is the tokenizer's actual API name;
    # renaming it here would break the call.
    sep_id = self.tokenizer.covert_token_to_id("[SEP]")
    for text in batch_text:
        if self.field_config.need_convert:
            tokens = self.tokenizer.tokenize(text)
            src_id = self.tokenizer.convert_tokens_to_ids(tokens)
        else:
            src_id = text.split(" ")
        # Truncate to max_len - 1 to leave room for the trailing [SEP].
        if len(src_id) > max_len - 1:
            logging.warning('input instance is too long (max %d): %s',
                            max_len - 1, text)
            src_id = truncation_words(src_id, max_len - 1,
                                      self.field_config.truncation_type)
        if src_id[-1] != sep_id:
            src_id.append(sep_id)
        if src_id.count(sep_id) > self.max_item_len:
            raise ValueError("too many items. expected max is %d, but got %d" %
                             (self.max_item_len, src_id.count(sep_id)))
        src_ids.append(src_id)

        # Collect, per [SEP]-terminated item, the token positions (padded
        # with zeros up to max_name_tokens) and the item's token count.
        idx_begin = 0
        block_pos_tmp = []
        block_len_tmp = []
        for idx_end, tid in enumerate(src_id):
            if tid == sep_id:
                # NOTE(review): if one item exceeds max_name_tokens tokens,
                # supp_num goes negative and the position list overflows the
                # expected width — confirm upstream guarantees item length.
                supp_num = self.max_name_tokens - (idx_end - idx_begin)
                block_pos_tmp.append(
                    list(range(idx_begin, idx_end)) + [0] * supp_num)
                block_len_tmp.append(idx_end - idx_begin)
                idx_begin = idx_end + 1
        assert all(x > 0 for x in block_len_tmp), \
            'token len should > 0: %s' % text
        name_len.append(len(block_pos_tmp))
        name_block_pos.append(block_pos_tmp)
        name_block_len.append(block_len_tmp)

    padded_ids, mask_ids, batch_seq_lens = pad_batch_data(
        src_ids,
        max_len=max_len,
        pad_idx=self.field_config.padding_id,
        return_input_mask=True,
        return_seq_lens=True,
        paddle_version_code=self.paddle_version_code)
    name_len = np.array(name_len).astype('int64').reshape(self.seq_len_shape)
    batch_name_pos = pad_batch_data(
        name_block_pos,
        shape=[-1, self.max_item_len, self.max_name_tokens],
        pad_idx=[0] * self.max_name_tokens)
    batch_name_block_len = pad_batch_data(
        name_block_len, shape=[-1, self.max_item_len])
    return [padded_ids, mask_ids, batch_seq_lens, name_len,
            batch_name_pos, batch_name_block_len]