def get_inputdata(query):
    """Tokenize *query* with the BERT char tokenizer and return its input features.

    Returns:
        (word_ids, word_mask, word_segment_ids) — ids from the vocab, an
        all-ones attention mask, and all-zero segment ids (single sentence).
    """
    tokenizer = tokenization.CharTokenizer(vocab_file=bert_vocab_file)
    ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(query))
    mask = [1 for _ in ids]
    segments = [0 for _ in ids]
    return ids, mask, segments
def input_x(self, TARGET, TEXT):
    """Convert one csv row (TARGET, TEXT) into BERT input features.

    Called repeatedly by Dataset, once per row.

    Returns:
        (word_ids, word_mask, word_segment_ids, TARGET + TEXT) where the
        first three are the paired-sentence features produced by
        convert_single_example_simple with max_seq_length=180.
    """
    # Strip unwanted characters from both fields (module-level `pattern`).
    TARGET = re.sub(pattern, "", TARGET)
    TEXT = re.sub(pattern, "", TEXT)
    if self.token is None:
        # Read the vocab path only once, when the tokenizer is first built
        # (previously the file was re-opened on every call). strip() removes
        # the trailing newline that readline() keeps, which would otherwise
        # corrupt the path handed to CharTokenizer.
        with open("pre_trained_root.txt", "r") as f:
            vocab_root = f.readline().strip()
        self.token = tokenization.CharTokenizer(vocab_file=vocab_root)
    word_ids, word_mask, word_segment_ids = \
        convert_single_example_simple(max_seq_length=180, tokenizer=self.token,
                                      text_a=TARGET, text_b=TEXT)
    return word_ids, word_mask, word_segment_ids, TARGET + TEXT
def input_x(self, news):
    """Convert one csv row (news text) into BERT input features.

    Called repeatedly by Dataset; the tokenizer is built lazily on first use.
    """
    if self.token is None:
        vocab_path = os.path.join(config.DATA_PATH, "model",
                                  "chinese_L-12_H-768_A-12", 'vocab.txt')
        self.token = tokenization.CharTokenizer(vocab_file=vocab_path)
    features = convert_single_example_simple(
        max_seq_length=config.max_seq_length, tokenizer=self.token, text_a=news)
    ids, mask, segments = features
    return ids, mask, segments
def input_x(self, sentence):
    """Convert one csv row (sentence) into BERT input features.

    Called repeatedly by Dataset; the tokenizer is built lazily on first use.
    """
    if self.token is None:
        vocab_path = os.path.join(DATA_PATH, "model",
                                  "multi_cased_L-12_H-768_A-12", 'vocab.txt')
        self.token = tokenization.CharTokenizer(vocab_file=vocab_path)
    features = convert_single_example_simple(
        max_seq_length=256, tokenizer=self.token, text_a=sentence)
    ids, mask, segments = features
    return ids, mask, segments
def input_x(self, TARGET, TEXT):
    """One csv row (TARGET, TEXT) -> model input; called repeatedly by Dataset.

    Available pretrained model variants:
        multi_cased_L-12_H-768_A-12
        chinese_roberta_wwm_large_ext_L-24_H-1024_A-16
        chinese_L-12_H-768_A-12

    Currently this only lazily initializes the tokenizer and returns an
    empty string — feature construction is disabled in this version.
    """
    if self.token is None:
        vocab_path = os.path.join(sys.path[0], 'chinese_L-12_H-768_A-12',
                                  'chinese_L-12_H-768_A-12', 'vocab.txt')
        # Only the vocab file is needed to build the char tokenizer.
        self.token = tokenization.CharTokenizer(vocab_file=vocab_path)
    return ""
# 加载bert模型 tvars = tf.trainable_variables() (assignment, initialized_variable_names) = modeling.get_assignment_map_from_checkpoint( tvars, init_checkpoint) tf.train.init_from_checkpoint(init_checkpoint, assignment) # 获取最后一层。 output_layer = model.get_sequence_output( ) # 这个获取每个token的output 输出[batch_size, seq_length, embedding_size] 如果做seq2seq 或者ner 用这个 output_layer_pooled = model.get_pooled_output() # 这个获取句子的output with tf.Session() as sess: sess.run(tf.global_variables_initializer()) query = u'今天去哪里吃' # word_ids, word_mask, word_segment_ids=get_inputdata(query) token = tokenization.CharTokenizer(vocab_file=bert_vocab_file) word_ids, word_mask, word_segment_ids = convert_single_example_simple( max_seq_length=32, tokenizer=token, text_a=query, text_b='这里吃') print(len(word_ids)) print(word_mask) print(word_segment_ids) fd = { input_ids: [word_ids], input_mask: [word_mask], segment_ids: [word_segment_ids] } last, last2 = sess.run([output_layer, output_layer_pooled], feed_dict=fd) print('last shape:{}, last2 shape: {}'.format(last.shape, last2.shape)) pass