def input_x(self, text):
    '''One row of the csv used as input x; this method is called repeatedly by Dataset.'''
    # Lazily build the BERT tokenizer from the pretrained vocab on first call.
    if self.token is None:
        bert_vocab_file = os.path.join(config.DATA_PATH, "model",
                                       "uncased_L-24_H-1024_A-16", 'vocab.txt')
        self.token = tokenization.FullTokenizer(vocab_file=bert_vocab_file)
    # pattern = "[!]+"
    # text = re.sub(pattern, '', text)
    # Decode HTML entities (e.g. &amp; -> &) before tokenizing.
    text = html.unescape(text)
    # Earlier digit/punctuation stripping, currently disabled:
    # pattern = re.compile(r'[0-9]|[%s]+' % punctuation)
    # tmp = re.sub(pattern, '', text).split(' ')
    # text = ' '.join(word for word in tmp if word != '')
    word_ids, word_mask, word_segment_ids = convert_single_example_simple(
        max_seq_length=config.max_seq_length, tokenizer=self.token, text_a=text)
    return word_ids, word_mask, word_segment_ids
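# This and the following input_x variants all delegate to convert_single_example_simple(),
# which is not shown in this file. The function below is only a sketch of what that helper
# is assumed to do, modelled on convert_single_example from Google's run_classifier.py:
# build [CLS]/[SEP]-delimited token ids, a padding mask, and segment ids, zero-padded to
# max_seq_length. The name and details are assumptions, not the original implementation.
def convert_single_example_simple_sketch(max_seq_length, tokenizer, text_a, text_b=None):
    tokens_a = tokenizer.tokenize(text_a)
    tokens_b = tokenizer.tokenize(text_b) if text_b else None
    if tokens_b:
        # Reserve room for [CLS], [SEP], [SEP]; trim the longer segment first.
        while len(tokens_a) + len(tokens_b) > max_seq_length - 3:
            (tokens_a if len(tokens_a) > len(tokens_b) else tokens_b).pop()
    else:
        tokens_a = tokens_a[:max_seq_length - 2]  # room for [CLS] and [SEP]
    tokens = ["[CLS]"] + tokens_a + ["[SEP]"]
    segment_ids = [0] * len(tokens)
    if tokens_b:
        tokens += tokens_b + ["[SEP]"]
        segment_ids += [1] * (len(tokens_b) + 1)
    input_ids = tokenizer.convert_tokens_to_ids(tokens)
    input_mask = [1] * len(input_ids)
    # Zero-pad so every example has length max_seq_length.
    padding = [0] * (max_seq_length - len(input_ids))
    return input_ids + padding, input_mask + padding, segment_ids + padding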
def input_x(self, TARGET, TEXT):
    '''One row of the csv used as input x; this method is called repeatedly by Dataset.'''
    # Topic-expansion experiment, currently disabled:
    # with open('topic.txt', 'r', encoding='utf-8') as f:
    #     topic = json.load(f)
    # `pattern` is a cleaning regex assumed to be defined at module level.
    TARGET = re.sub(pattern, "", TARGET)
    TEXT = re.sub(pattern, "", TEXT)
    # if TARGET == '深圳禁摩限电':
    #     TARGET += topic['深圳禁摩限电']
    # elif TARGET == '春节放鞭炮':
    #     TARGET += topic['春节放鞭炮']
    # elif TARGET == 'IphoneSE':
    #     TARGET += topic['IphoneSE']
    # elif TARGET == '开放二胎':
    #     TARGET += topic['开放二胎']
    # else:
    #     TARGET += topic['俄罗斯在叙利亚的反恐行动']
    # Read the vocab path from a small text file; strip the trailing newline so the path is usable.
    # The `with` block closes the file, so no explicit close() is needed.
    with open("pre_trained_root.txt", "r") as f:
        vocab_root = f.readline().strip()
    if self.token is None:
        # bert_vocab_file = os.path.join(DATA_PATH, "model", "multi_cased_L-12_H-768_A-12", 'vocab.txt')
        self.token = tokenization.CharTokenizer(vocab_file=vocab_root)
    word_ids, word_mask, word_segment_ids = convert_single_example_simple(
        max_seq_length=180, tokenizer=self.token, text_a=TARGET, text_b=TEXT)
    return word_ids, word_mask, word_segment_ids, TARGET + TEXT
def input_x(self, sentence):
    '''One row of the csv used as input x; this method is called repeatedly by Dataset.'''
    if self.token is None:
        bert_vocab_file = os.path.join(config.DATA_PATH, "model",
                                       "uncased_L-24_H-1024_A-16", 'vocab.txt')
        self.token = tokenization.FullTokenizer(vocab_file=bert_vocab_file)
    word_ids, word_mask, word_segment_ids = convert_single_example_simple(
        max_seq_length=config.max_seq_length, tokenizer=self.token, text_a=sentence)
    return word_ids, word_mask, word_segment_ids
def input_x(self, news):
    '''One row of the csv used as input x; this method is called repeatedly by Dataset.'''
    if self.token is None:
        bert_vocab_file = os.path.join(config.DATA_PATH, "model",
                                       "chinese_L-12_H-768_A-12", 'vocab.txt')
        self.token = tokenization.CharTokenizer(vocab_file=bert_vocab_file)
    word_ids, word_mask, word_segment_ids = convert_single_example_simple(
        max_seq_length=config.max_seq_length, tokenizer=self.token, text_a=news)
    return word_ids, word_mask, word_segment_ids
def input_x(self, sentence):
    '''One row of the csv used as input x; this method is called repeatedly by Dataset.'''
    if self.token is None:
        bert_vocab_file = os.path.join(DATA_PATH, "model",
                                       "multi_cased_L-12_H-768_A-12", 'vocab.txt')
        self.token = tokenization.CharTokenizer(vocab_file=bert_vocab_file)
    word_ids, word_mask, word_segment_ids = convert_single_example_simple(
        max_seq_length=256, tokenizer=self.token, text_a=sentence)
    return word_ids, word_mask, word_segment_ids
def input_x(self, question1, question2):
    '''One row of the csv used as input x; this method is called repeatedly by Dataset.'''
    if self.token is None:
        bert_vocab_file = os.path.join(config.DATA_PATH, "model",
                                       "uncased_L-12_H-768_A-12", 'vocab.txt')
        self.token = tokenization.FullTokenizer(vocab_file=bert_vocab_file)
    # Empty csv cells arrive as NaN; replace them with '' and lowercase the rest.
    question1 = question1.lower() if question1 is not np.nan else ''
    question2 = question2.lower() if question2 is not np.nan else ''
    word_ids, word_mask, word_segment_ids = convert_single_example_simple(
        max_seq_length=config.max_seq_length, tokenizer=self.token,
        text_a=question1, text_b=question2)
    return word_ids, word_mask, word_segment_ids
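# The `question1 is not np.nan` guard above is an identity check, not a general NaN test;
# it works when pandas fills empty csv cells with the np.nan singleton. A small hypothetical
# helper (not part of the original project) that tolerates any non-string cell value could
# look like this:
import numpy as np

def text_or_empty(value):
    # Treat anything that is not a string (NaN, None, numbers) as a missing field.
    return value.lower() if isinstance(value, str) else ''

assert text_or_empty(np.nan) == ''
assert text_or_empty('Hello World') == 'hello world'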
# Load the pretrained BERT weights into the graph.
tvars = tf.trainable_variables()
(assignment, initialized_variable_names) = modeling.get_assignment_map_from_checkpoint(
    tvars, init_checkpoint)
tf.train.init_from_checkpoint(init_checkpoint, assignment)

# Grab the last layer's outputs.
# Per-token output, shape [batch_size, seq_length, embedding_size]; use this for seq2seq or NER.
output_layer = model.get_sequence_output()
# Pooled sentence-level output.
output_layer_pooled = model.get_pooled_output()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    query = u'今天去哪里吃'
    # word_ids, word_mask, word_segment_ids = get_inputdata(query)
    token = tokenization.CharTokenizer(vocab_file=bert_vocab_file)
    word_ids, word_mask, word_segment_ids = convert_single_example_simple(
        max_seq_length=32, tokenizer=token, text_a=query, text_b='这里吃')
    print(len(word_ids))
    print(word_mask)
    print(word_segment_ids)
    fd = {
        input_ids: [word_ids],
        input_mask: [word_mask],
        segment_ids: [word_segment_ids],
    }
    last, last2 = sess.run([output_layer, output_layer_pooled], feed_dict=fd)
    print('last shape:{}, last2 shape: {}'.format(last.shape, last2.shape))
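# The forward-pass snippet above assumes that the placeholders, the BertModel, and the
# checkpoint/vocab paths were created earlier in the same script; none of that setup is
# shown here. Below is only a minimal sketch of what it presumably looks like; the model
# directory and file paths are assumptions, not paths from the original project.
import tensorflow as tf
import modeling

bert_root = 'chinese_L-12_H-768_A-12'               # assumed model directory
bert_config_file = bert_root + '/bert_config.json'  # assumed path
bert_vocab_file = bert_root + '/vocab.txt'          # assumed path
init_checkpoint = bert_root + '/bert_model.ckpt'    # assumed path

bert_config = modeling.BertConfig.from_json_file(bert_config_file)

# Placeholders fed with the outputs of convert_single_example_simple (max_seq_length=32 above).
input_ids = tf.placeholder(tf.int32, shape=[None, 32], name='input_ids')
input_mask = tf.placeholder(tf.int32, shape=[None, 32], name='input_mask')
segment_ids = tf.placeholder(tf.int32, shape=[None, 32], name='segment_ids')

model = modeling.BertModel(
    config=bert_config,
    is_training=False,
    input_ids=input_ids,
    input_mask=input_mask,
    token_type_ids=segment_ids,
    use_one_hot_embeddings=False)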