def __init__(self,
             num_symbols,
             num_embed_units,
             num_units,
             num_layers,
             embed,
             entity_embed=None,
             num_entities=0,
             num_trans_units=100,
             learning_rate=0.0001,
             learning_rate_decay_factor=0.95,
             max_gradient_norm=5.0,
             num_samples=512,
             max_length=60,
             output_alignments=True,
             use_lstm=False):
    self.posts = tf.placeholder(tf.string, (None, None), 'enc_inps')  # batch * encoder_len
    self.posts_length = tf.placeholder(tf.int32, (None), 'enc_lens')  # batch
    self.responses = tf.placeholder(tf.string, (None, None), 'dec_inps')  # batch * decoder_len
    self.responses_length = tf.placeholder(tf.int32, (None), 'dec_lens')  # batch
    self.entities = tf.placeholder(tf.string, (None, None), 'entities')  # batch * entity_len
    self.entity_masks = tf.placeholder(tf.string, (None, None), 'entity_masks')  # batch * entity_len
    self.triples = tf.placeholder(tf.string, (None, None, 3), 'triples')  # batch * triple_num * 3
    self.posts_triple = tf.placeholder(tf.int32, (None, None, 1), 'enc_triples')  # batch * encoder_len * 1
    self.responses_triple = tf.placeholder(tf.string, (None, None, 3), 'dec_triples')  # batch * decoder_len * 3
    self.match_triples = tf.placeholder(tf.int32, (None, None), 'match_triples')  # batch * decoder_len

    encoder_batch_size, encoder_len = tf.unstack(tf.shape(self.posts))
    triple_num = tf.shape(self.triples)[1]

    # use_triples marks with 1 the decoding steps whose target word was matched to a knowledge triple
    #use_triples = tf.reduce_sum(tf.cast(tf.greater_equal(self.match_triples, 0), tf.float32), axis=-1)
    one_hot_triples = tf.one_hot(self.match_triples, triple_num)  # batch * decoder_len * triple_num
    use_triples = tf.reduce_sum(one_hot_triples, axis=[2])  # batch * decoder_len

    # build the vocab tables (string to index, index to string)
    self.symbol2index = MutableHashTable(
        key_dtype=tf.string,
        value_dtype=tf.int64,
        default_value=UNK_ID,
        shared_name="in_table",
        name="in_table",
        checkpoint=True)
    self.index2symbol = MutableHashTable(
        key_dtype=tf.int64,
        value_dtype=tf.string,
        default_value='_UNK',
        shared_name="out_table",
        name="out_table",
        checkpoint=True)
    self.entity2index = MutableHashTable(
        key_dtype=tf.string,
        value_dtype=tf.int64,
        default_value=NONE_ID,
        shared_name="entity_in_table",
        name="entity_in_table",
        checkpoint=True)
    self.index2entity = MutableHashTable(
        key_dtype=tf.int64,
        value_dtype=tf.string,
        default_value='_NONE',
        shared_name="entity_out_table",
        name="entity_out_table",
        checkpoint=True)

    # map strings to indices through the vocab tables
    self.posts_word_id = self.symbol2index.lookup(self.posts)  # batch * encoder_len
    self.posts_entity_id = self.entity2index.lookup(self.posts)  # batch * encoder_len
    #self.posts_word_id = tf.Print(self.posts_word_id, ['use_triples', use_triples, 'one_hot_triples', one_hot_triples], summarize=1e6)
    self.responses_target = self.symbol2index.lookup(self.responses)  # batch * decoder_len

    batch_size, decoder_len = tf.shape(self.responses)[0], tf.shape(self.responses)[1]
    # drop the last column of responses_target and prepend GO_ID
    self.responses_word_id = tf.concat(
        [tf.ones([batch_size, 1], dtype=tf.int64) * GO_ID,
         tf.split(self.responses_target, [decoder_len - 1, 1], 1)[0]], 1)  # batch * decoder_len
    # 1 for every position inside the response, 0 on padding
    self.decoder_mask = tf.reshape(
        tf.cumsum(tf.one_hot(self.responses_length - 1, decoder_len),
                  reverse=True, axis=1), [-1, decoder_len])  # batch * decoder_len

    # build the embedding table (index to vector)
    if embed is None:
        # initialize the embedding randomly
        self.embed = tf.get_variable('word_embed', [num_symbols, num_embed_units], tf.float32)
    else:
        # initialize the embedding with pre-trained word vectors
        self.embed = tf.get_variable('word_embed', dtype=tf.float32, initializer=embed)
    if entity_embed is None:
        # initialize the entity embedding randomly
        self.entity_trans = tf.get_variable('entity_embed', [num_entities, num_trans_units],
                                            tf.float32, trainable=False)
    else:
        # initialize the entity embedding with pre-trained vectors
        self.entity_trans = tf.get_variable('entity_embed', dtype=tf.float32,
                                            initializer=entity_embed, trainable=False)
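    # NOTE (added): entity_trans holds the pre-trained entity/relation vectors
    # and is created with trainable=False, so it never updates during training;
    # the tanh dense layer below ('trans_transformation') is the only learned
    # adapter on top of these frozen embeddings.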
    self.entity_trans_transformed = tf.layers.dense(
        self.entity_trans, num_trans_units, activation=tf.tanh, name='trans_transformation')
    padding_entity = tf.get_variable('entity_padding_embed', [7, num_trans_units],
                                     dtype=tf.float32, initializer=tf.zeros_initializer())
    self.entity_embed = tf.concat([padding_entity, self.entity_trans_transformed], axis=0)

    triples_embedding = tf.reshape(
        tf.nn.embedding_lookup(self.entity_embed, self.entity2index.lookup(self.triples)),
        [encoder_batch_size, triple_num, 3 * num_trans_units])
    entities_word_embedding = tf.reshape(
        tf.nn.embedding_lookup(self.embed, self.symbol2index.lookup(self.entities)),
        [encoder_batch_size, -1, num_embed_units])

    self.encoder_input = tf.nn.embedding_lookup(self.embed, self.posts_word_id)  # batch * len * unit
    self.decoder_input = tf.nn.embedding_lookup(self.embed, self.responses_word_id)  # batch * len * unit

    encoder_cell = MultiRNNCell([GRUCell(num_units) for _ in range(num_layers)])
    decoder_cell = MultiRNNCell([GRUCell(num_units) for _ in range(num_layers)])

    # rnn encoder
    encoder_output, encoder_state = dynamic_rnn(
        encoder_cell, self.encoder_input, self.posts_length, dtype=tf.float32, scope="encoder")

    # get output projection function
    output_fn, selector_fn, sequence_loss, sampled_sequence_loss, total_loss = \
        output_projection_layer(num_units, num_symbols, num_samples)

    # decoder used for training
    with tf.variable_scope('decoder'):
        # get attention function
        attention_keys_init, attention_values_init, attention_score_fn_init, attention_construct_fn_init \
            = prepare_attention(encoder_output, 'bahdanau', num_units,
                                imem=triples_embedding,
                                output_alignments=output_alignments)  # 'luong', num_units
        decoder_fn_train = attention_decoder_fn_train(
            encoder_state,
            attention_keys_init, attention_values_init,
            attention_score_fn_init, attention_construct_fn_init,
            output_alignments=output_alignments,
            max_length=tf.reduce_max(self.responses_length))
        self.decoder_output, _, alignments_ta = dynamic_rnn_decoder(
            decoder_cell, decoder_fn_train, self.decoder_input,
            self.responses_length, scope="decoder_rnn")
        if output_alignments:
            self.alignments = tf.transpose(alignments_ta.stack(), perm=[1, 0, 2])
            #self.alignments = tf.Print(self.alignments, [self.alignments], summarize=1e8)
            self.decoder_loss, self.ppx_loss, self.sentence_ppx = total_loss(
                self.decoder_output, self.responses_target, self.decoder_mask,
                self.alignments, triples_embedding, use_triples, one_hot_triples)
            self.sentence_ppx = tf.identity(self.sentence_ppx, 'ppx_loss')
            #self.decoder_loss = tf.Print(self.decoder_loss, ['decoder_loss', self.decoder_loss], summarize=1e6)
        else:
            self.decoder_loss, self.sentence_ppx = sequence_loss(
                self.decoder_output, self.responses_target, self.decoder_mask)
            self.sentence_ppx = tf.identity(self.sentence_ppx, 'ppx_loss')

    # decoder used for inference
    with tf.variable_scope('decoder', reuse=True):
        # get attention function
        attention_keys, attention_values, attention_score_fn, attention_construct_fn \
            = prepare_attention(encoder_output, 'bahdanau', num_units, reuse=True,
                                imem=triples_embedding,
                                output_alignments=output_alignments)  # 'luong', num_units
        decoder_fn_inference = attention_decoder_fn_inference(
            output_fn, encoder_state,
            attention_keys, attention_values,
            attention_score_fn, attention_construct_fn,
            self.embed, GO_ID, EOS_ID, max_length, num_symbols,
            imem=entities_word_embedding, selector_fn=selector_fn)
        self.decoder_distribution, _, output_ids_ta = dynamic_rnn_decoder(
            decoder_cell, decoder_fn_inference, scope="decoder_rnn")
        if output_alignments:
            output_len = tf.shape(self.decoder_distribution)[1]
            output_ids = tf.transpose(output_ids_ta.gather(tf.range(output_len)))
            # positive ids are vocabulary words, negative ids index the entity list
            word_ids = tf.cast(tf.clip_by_value(output_ids, 0, num_symbols), tf.int64)
            entity_ids = tf.reshape(
                tf.clip_by_value(-output_ids, 0, num_symbols) +
                tf.reshape(tf.range(encoder_batch_size) * tf.shape(entities_word_embedding)[1], [-1, 1]),
                [-1])
            entities = tf.reshape(
                tf.gather(tf.reshape(self.entities, [-1]), entity_ids), [-1, output_len])
            words = self.index2symbol.lookup(word_ids)
            self.generation = tf.where(output_ids > 0, words, entities, name='generation')
        else:
            self.generation_index = tf.argmax(self.decoder_distribution, 2)
            self.generation = self.index2symbol.lookup(self.generation_index, name='generation')

    # initialize the training process
    self.learning_rate = tf.Variable(float(learning_rate), trainable=False, dtype=tf.float32)
    self.learning_rate_decay_op = self.learning_rate.assign(
        self.learning_rate * learning_rate_decay_factor)
    self.global_step = tf.Variable(0, trainable=False)
    self.params = tf.global_variables()

    # compute, clip, and apply the gradients of the parameters
    #opt = tf.train.GradientDescentOptimizer(self.learning_rate)
    opt = tf.train.AdamOptimizer(learning_rate=learning_rate)
    self.lr = opt._lr
    gradients = tf.gradients(self.decoder_loss, self.params)
    clipped_gradients, self.gradient_norm = tf.clip_by_global_norm(gradients, max_gradient_norm)
    self.update = opt.apply_gradients(zip(clipped_gradients, self.params),
                                      global_step=self.global_step)

    tf.summary.scalar('decoder_loss', self.decoder_loss)
    for each in tf.trainable_variables():
        tf.summary.histogram(each.name, each)
    self.merged_summary_op = tf.summary.merge_all()

    self.saver = tf.train.Saver(write_version=tf.train.SaverDef.V2, max_to_keep=3,
                                pad_step_number=True, keep_checkpoint_every_n_hours=1.0)
    self.saver_epoch = tf.train.Saver(write_version=tf.train.SaverDef.V2, max_to_keep=1000,
                                      pad_step_number=True)
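# --------------------------------------------------------------------------
# Hedged sketch (added): the GO-shift and cumsum length-mask used above, in
# plain numpy. `np_decoder_inputs`, GO_ID=1, and the toy ids are illustrative
# assumptions, not part of the repo.
# --------------------------------------------------------------------------
import numpy as np

def np_decoder_inputs(target_ids, lengths, go_id=1):
    """Build decoder inputs (prepend GO, drop last column) and the 0/1 mask."""
    batch_size, decoder_len = target_ids.shape
    # tf.concat([GO, tf.split(target, [decoder_len - 1, 1], 1)[0]], 1)
    # == shift the targets one step to the right
    dec_in = np.concatenate(
        [np.full((batch_size, 1), go_id, dtype=target_ids.dtype),
         target_ids[:, :-1]], axis=1)
    # tf.cumsum(tf.one_hot(lengths - 1, decoder_len), reverse=True, axis=1)
    # == 1.0 for every position before `length`, 0.0 afterwards
    one_hot = np.eye(decoder_len)[lengths - 1]
    mask = np.cumsum(one_hot[:, ::-1], axis=1)[:, ::-1]
    return dec_in, mask

dec_in, mask = np_decoder_inputs(
    np.array([[5, 6, 7, 2], [8, 9, 2, 0]]),  # 2 == EOS, 0 == PAD (assumed)
    lengths=np.array([4, 3]))
print(dec_in)  # [[1 5 6 7], [1 8 9 2]]
print(mask)    # [[1. 1. 1. 1.], [1. 1. 1. 0.]]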
def __init__(
        self,
        num_symbols,            # vocabulary size
        num_embed_units,        # word-embedding size
        num_units,              # units per RNN layer
        num_layers,             # number of RNN layers
        embed,                  # word embeddings
        entity_embed=None,      # entity + relation embeddings
        num_entities=0,         # total number of entities + relations
        num_trans_units=100,    # entity-embedding size
        memory_units=100,
        learning_rate=0.0001,   # learning rate
        learning_rate_decay_factor=0.95,  # learning-rate decay (not actually used)
        max_gradient_norm=5.0,
        num_samples=500,        # number of samples for sampled softmax
        max_length=60,
        mem_use=True,
        output_alignments=True,
        use_lstm=False):
    self.posts = tf.placeholder(tf.string, (None, None), 'enc_inps')  # [batch_size, encoder_len]
    self.posts_length = tf.placeholder(tf.int32, (None), 'enc_lens')  # [batch_size]
    self.responses = tf.placeholder(tf.string, (None, None), 'dec_inps')  # [batch_size, decoder_len]
    self.responses_length = tf.placeholder(tf.int32, (None), 'dec_lens')  # [batch_size]
    self.entities = tf.placeholder(tf.string, (None, None, None), 'entities')  # [batch_size, triple_num, triple_len]
    self.entity_masks = tf.placeholder(tf.string, (None, None), 'entity_masks')  # unused
    self.triples = tf.placeholder(tf.string, (None, None, None, 3), 'triples')  # [batch_size, triple_num, triple_len, 3]
    self.posts_triple = tf.placeholder(tf.int32, (None, None, 1), 'enc_triples')  # [batch_size, encoder_len, 1]
    self.responses_triple = tf.placeholder(tf.string, (None, None, 3), 'dec_triples')  # [batch_size, decoder_len, 3]
    self.match_triples = tf.placeholder(tf.int32, (None, None, None), 'match_triples')  # [batch_size, decoder_len, triple_num]

    # encoder batch size and encoder length
    encoder_batch_size, encoder_len = tf.unstack(tf.shape(self.posts))
    triple_num = tf.shape(self.triples)[1]  # number of knowledge graphs per post
    triple_len = tf.shape(self.triples)[2]  # number of triples per graph

    # which knowledge triples are used
    one_hot_triples = tf.one_hot(self.match_triples, triple_len)  # [batch_size, decoder_len, triple_num, triple_len]
    # marks with 1 the decoding steps that used a knowledge triple
    use_triples = tf.reduce_sum(one_hot_triples, axis=[2, 3])  # [batch_size, decoder_len]

    # hash table mapping words to indices
    self.symbol2index = MutableHashTable(
        key_dtype=tf.string,        # dtype of the key tensors
        value_dtype=tf.int64,       # dtype of the value tensors
        default_value=UNK_ID,       # default value for missing keys
        shared_name="in_table",     # if non-empty, the table is shared under this name across sessions
        name="in_table",            # op name
        checkpoint=True)            # if True, the table contents are saved to and restored from checkpoints
    # hash table mapping indices to words
    self.index2symbol = MutableHashTable(key_dtype=tf.int64,
                                         value_dtype=tf.string,
                                         default_value='_UNK',
                                         shared_name="out_table",
                                         name="out_table",
                                         checkpoint=True)
    # hash table mapping entities to indices
    self.entity2index = MutableHashTable(key_dtype=tf.string,
                                         value_dtype=tf.int64,
                                         default_value=NONE_ID,
                                         shared_name="entity_in_table",
                                         name="entity_in_table",
                                         checkpoint=True)
    # hash table mapping indices to entities
    self.index2entity = MutableHashTable(key_dtype=tf.int64,
                                         value_dtype=tf.string,
                                         default_value='_NONE',
                                         shared_name="entity_out_table",
                                         name="entity_out_table",
                                         checkpoint=True)

    self.posts_word_id = self.symbol2index.lookup(self.posts)  # [batch_size, encoder_len]
    self.posts_entity_id = self.entity2index.lookup(self.posts)  # [batch_size, encoder_len]
    self.responses_target = self.symbol2index.lookup(self.responses)  # [batch_size, decoder_len]

    # decoder batch size and decoder length
    batch_size, decoder_len = tf.shape(self.responses)[0], tf.shape(self.responses)[1]
    # drop the last column of responses_target and prepend GO_ID
    self.responses_word_id = tf.concat([
        tf.ones([batch_size, 1], dtype=tf.int64) * GO_ID,
        tf.split(self.responses_target, [decoder_len - 1, 1], 1)[0]
    ], 1)  # [batch_size, decoder_len]
    # response mask: 1 inside the response, 0 on padding
    self.decoder_mask = tf.reshape(
        tf.cumsum(tf.one_hot(self.responses_length - 1, decoder_len),
                  reverse=True, axis=1), [-1, decoder_len])  # [batch_size, decoder_len]

    # initialize the word and entity embeddings: use the given values if
    # provided, otherwise initialize randomly
    if embed is None:
        self.embed = tf.get_variable('word_embed', [num_symbols, num_embed_units], tf.float32)
    else:
        self.embed = tf.get_variable('word_embed', dtype=tf.float32, initializer=embed)
    if entity_embed is None:
        # the entity embedding is not updated during training
        self.entity_trans = tf.get_variable('entity_embed', [num_entities, num_trans_units],
                                            tf.float32, trainable=False)
    else:
        self.entity_trans = tf.get_variable('entity_embed', dtype=tf.float32,
                                            initializer=entity_embed, trainable=False)

    # pass the entity embedding through a fully connected layer
    self.entity_trans_transformed = tf.layers.dense(self.entity_trans, num_trans_units,
                                                    activation=tf.tanh, name='trans_transformation')
    # embeddings for the 7 special markers
    # ['_NONE', '_PAD_H', '_PAD_R', '_PAD_T', '_NAF_H', '_NAF_R', '_NAF_T']
    padding_entity = tf.get_variable('entity_padding_embed', [7, num_trans_units],
                                     dtype=tf.float32, initializer=tf.zeros_initializer())
    self.entity_embed = tf.concat([padding_entity, self.entity_trans_transformed], axis=0)

    # triples_embedding: [batch_size, triple_num, triple_len, 3*num_trans_units],
    # embeddings of the knowledge-graph triples
    triples_embedding = tf.reshape(
        tf.nn.embedding_lookup(self.entity_embed, self.entity2index.lookup(self.triples)),
        [encoder_batch_size, triple_num, -1, 3 * num_trans_units])
    # entities_word_embedding: [batch_size, triple_num*triple_len, num_embed_units],
    # word embeddings of all entities used in the knowledge graphs
    entities_word_embedding = tf.reshape(
        tf.nn.embedding_lookup(self.embed, self.symbol2index.lookup(self.entities)),
        [encoder_batch_size, -1, num_embed_units])

    # split the triples into head, relation and tail,
    # each [batch_size, triple_num, triple_len, num_trans_units]
    head, relation, tail = tf.split(triples_embedding, [num_trans_units] * 3, axis=3)

    # static graph attention
    with tf.variable_scope('graph_attention'):
        # concatenate head and tail: [batch_size, triple_num, triple_len, 2*num_trans_units]
        head_tail = tf.concat([head, tail], axis=3)
        # dense layer over head and tail: [batch_size, triple_num, triple_len, num_trans_units]
        head_tail_transformed = tf.layers.dense(head_tail, num_trans_units,
                                                activation=tf.tanh, name='head_tail_transform')
        # dense layer over the relation: [batch_size, triple_num, triple_len, num_trans_units]
        relation_transformed = tf.layers.dense(relation, num_trans_units,
                                               name='relation_transform')
        # inner product of the two vectors gives the attention logit per triple
        e_weight = tf.reduce_sum(relation_transformed * head_tail_transformed,
                                 axis=3)  # [batch_size, triple_num, triple_len]
        alpha_weight = tf.nn.softmax(e_weight)  # [batch_size, triple_num, triple_len]
        # expand alpha_weight to [batch_size, triple_num, triple_len, 1] and
        # sum over axis 2, producing the static graph vectors
        graph_embed = tf.reduce_sum(tf.expand_dims(alpha_weight, 3) * head_tail,
                                    axis=2)  # [batch_size, triple_num, 2*num_trans_units]

    """graph_embed_input (left commented out in this version)
    1. The 1-D range [0, 1, ..., encoder_batch_size - 1] is reshaped into an
       [encoder_batch_size, 1, 1] tensor: [[[0]], [[1]], [[2]], ...].
    2. tf.tile repeats it encoder_len times along axis 1, giving
       [encoder_batch_size, encoder_len, 1]: [[[0], [0], ...]], ...].
    3. Concatenating with posts_triple ([batch_size, encoder_len, 1]) on axis 2
       yields an indices tensor of shape [batch_size, encoder_len, 2]:
       [
        [[0 0], [0 0], [0 0], [0 0], [0 1], [0 0], [0 2], [0 0], ... encoder_len],
        [[1 0], [1 0], [1 0], [1 0], [1 1], [1 0], [1 2], [1 0], ... encoder_len],
        [[2 0], [2 0], [2 0], [2 0], [2 1], [2 0], [2 2], [2 0], ... encoder_len],
        ... batch_size
       ]
    4. tf.gather_nd looks these indices up in graph_embed
       ([batch_size, triple_num, 2*num_trans_units]); e.g. [0, 2] means "batch
       element 0 uses graph 2 at this step", so that graph's static vector is
       filled in, giving [encoder_batch_size, encoder_len, 2*num_trans_units]:
       the static graph vector used at each time step.
    """
    # graph_embed_input = tf.gather_nd(graph_embed, tf.concat(
    #     [tf.tile(tf.reshape(tf.range(encoder_batch_size, dtype=tf.int32), [-1, 1, 1]), [1, encoder_len, 1]),
    #      self.posts_triple],
    #     axis=2))

    # turn responses_triple into entity embeddings [batch_size, decoder_len, 300],
    # marking which triple's embedding each response step used
    # triple_embed_input = tf.reshape(
    #     tf.nn.embedding_lookup(self.entity_embed, self.entity2index.lookup(self.responses_triple)),
    #     [batch_size, decoder_len, 3 * num_trans_units])

    post_word_input = tf.nn.embedding_lookup(self.embed, self.posts_word_id)  # [batch_size, encoder_len, num_embed_units]
    response_word_input = tf.nn.embedding_lookup(self.embed, self.responses_word_id)  # [batch_size, decoder_len, num_embed_units]

    # concatenating post_word_input and graph_embed_input would form the encoder
    # input [batch_size, encoder_len, num_embed_units+2*num_trans_units]
    # self.encoder_input = tf.concat([post_word_input, graph_embed_input], axis=2)
    # concatenating response_word_input and triple_embed_input would form the decoder
    # input [batch_size, decoder_len, num_embed_units+3*num_trans_units]
    # self.decoder_input = tf.concat([response_word_input, triple_embed_input], axis=2)

    encoder_cell = MultiRNNCell([GRUCell(num_units) for _ in range(num_layers)])
    decoder_cell = MultiRNNCell([GRUCell(num_units) for _ in range(num_layers)])

    # rnn encoder
    # encoder_state: [num_layers, 2, batch_size, num_units] for LSTM,
    # [num_layers, batch_size, num_units] for GRU
    encoder_output, encoder_state = tf.nn.dynamic_rnn(encoder_cell, post_word_input,
                                                      self.posts_length, dtype=tf.float32,
                                                      scope="encoder")
    # self.encoder_state_shape = tf.shape(encoder_state)

    ######## memory network ###
    response_encoder_cell = MultiRNNCell([GRUCell(num_units) for _ in range(num_layers)])
    response_encoder_output, response_encoder_state = tf.nn.dynamic_rnn(
        response_encoder_cell, response_word_input, self.responses_length,
        dtype=tf.float32, scope="response_encoder")

    # graph_embed:   [batch_size, triple_num, 2*num_trans_units]  static graph vectors
    # encoder_state: [num_layers, batch_size, num_units]
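    # NOTE (added): the two blocks below are parallel key-value memory reads
    # over the static graph vectors, one keyed by the post encoder state and
    # one by the response encoder state: each layer's state u scores the input
    # memories m (weight_a/weight_b), a softmax p picks graphs, and the output
    # memories c (weight_c) are summed into o and mapped back through weight_r.
    # A standalone numpy sketch of one read follows this version.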
    with tf.variable_scope("post_memory_network"):
        # map the static graph vectors to the input memory m
        post_input = tf.layers.dense(graph_embed, memory_units, use_bias=False,
                                     name="post_weight_a")
        post_input = tf.tile(
            tf.reshape(post_input, (1, encoder_batch_size, triple_num, memory_units)),
            multiples=(num_layers, 1, 1, 1))  # [num_layers, batch_size, triple_num, memory_units]
        # map the static graph vectors to the output memory c
        post_output = tf.layers.dense(graph_embed, memory_units, use_bias=False,
                                      name="post_weight_c")
        post_output = tf.tile(
            tf.reshape(post_output, (1, encoder_batch_size, triple_num, memory_units)),
            multiples=(num_layers, 1, 1, 1))  # [num_layers, batch_size, triple_num, memory_units]
        # map the question (encoder state) to the state vector u
        encoder_hidden_state = tf.reshape(tf.concat(encoder_state, axis=0),
                                          (num_layers, encoder_batch_size, num_units))
        post_state = tf.layers.dense(encoder_hidden_state, memory_units, use_bias=False,
                                     name="post_weight_b")
        post_state = tf.tile(
            tf.reshape(post_state, (num_layers, encoder_batch_size, 1, memory_units)),
            multiples=(1, 1, triple_num, 1))  # [num_layers, batch_size, triple_num, memory_units]
        # probability p
        post_p = tf.reshape(
            tf.nn.softmax(tf.reduce_sum(post_state * post_input, axis=3)),
            (num_layers, encoder_batch_size, triple_num, 1))  # [num_layers, batch_size, triple_num, 1]
        # output o
        post_o = tf.reduce_sum(post_output * post_p, axis=2)  # [num_layers, batch_size, memory_units]
        post_xstar = tf.concat(
            [tf.layers.dense(post_o, memory_units, use_bias=False, name="post_weight_r"),
             encoder_state],
            axis=2)  # [num_layers, batch_size, num_units+memory_units]

    with tf.variable_scope("response_memory_network"):
        # map the static graph vectors to the input memory m
        response_input = tf.layers.dense(graph_embed, memory_units, use_bias=False,
                                         name="response_weight_a")
        response_input = tf.tile(
            tf.reshape(response_input, (1, batch_size, triple_num, memory_units)),
            multiples=(num_layers, 1, 1, 1))  # [num_layers, batch_size, triple_num, memory_units]
        # map the static graph vectors to the output memory c
        response_output = tf.layers.dense(graph_embed, memory_units, use_bias=False,
                                          name="response_weight_c")
        response_output = tf.tile(
            tf.reshape(response_output, (1, batch_size, triple_num, memory_units)),
            multiples=(num_layers, 1, 1, 1))  # [num_layers, batch_size, triple_num, memory_units]
        # map the question (response encoder state) to the state vector u
        response_hidden_state = tf.reshape(tf.concat(response_encoder_state, axis=0),
                                           (num_layers, batch_size, num_units))
        response_state = tf.layers.dense(response_hidden_state, memory_units, use_bias=False,
                                         name="response_weight_b")
        response_state = tf.tile(
            tf.reshape(response_state, (num_layers, batch_size, 1, memory_units)),
            multiples=(1, 1, triple_num, 1))  # [num_layers, batch_size, triple_num, memory_units]
        # probability p
        response_p = tf.reshape(
            tf.nn.softmax(tf.reduce_sum(response_state * response_input, axis=3)),
            (num_layers, batch_size, triple_num, 1))  # [num_layers, batch_size, triple_num, 1]
        # output o
        response_o = tf.reduce_sum(response_output * response_p, axis=2)  # [num_layers, batch_size, memory_units]
        response_ystar = tf.concat(
            [tf.layers.dense(response_o, memory_units, use_bias=False, name="response_weight_r"),
             response_encoder_state],
            axis=2)  # [num_layers, batch_size, num_units+memory_units]

    with tf.variable_scope("memory_network"):
        memory_hidden_state = tf.layers.dense(
            tf.concat([post_xstar, response_ystar], axis=2), num_units,
            use_bias=False, activation=tf.tanh,
            name="output_weight")  # [num_layers, batch_size, num_units]
        memory_hidden_state = tf.reshape(memory_hidden_state,
                                         (num_layers * batch_size, num_units))
        memory_hidden_state = tuple(
            tf.split(memory_hidden_state, [batch_size] * num_layers, axis=0))
    # self.memory_hidden_state_shape = tf.shape(memory_hidden_state)
    ######## ###

    output_fn, selector_fn, sequence_loss, sampled_sequence_loss, total_loss = \
        output_projection_layer(num_units, num_symbols, num_samples)
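    # NOTE (added): output_projection_layer maps the num_units decoder states
    # up to the num_symbols vocabulary and returns both full-softmax losses
    # (sequence_loss / total_loss) and a sampled-softmax variant
    # (sampled_sequence_loss) controlled by num_samples, since a full softmax
    # over the whole vocabulary is expensive.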
    ######## decoder used for training ###
    with tf.variable_scope('decoder'):
        attention_keys_init, attention_values_init, attention_score_fn_init, attention_construct_fn_init \
            = prepare_attention(encoder_output, 'bahdanau', num_units,
                                imem=(graph_embed, triples_embedding),
                                output_alignments=output_alignments and mem_use)
        # function that handles each step's output and the next step's input
        # during training
        decoder_fn_train = attention_decoder_fn_train(
            memory_hidden_state,
            attention_keys_init, attention_values_init,
            attention_score_fn_init, attention_construct_fn_init,
            output_alignments=output_alignments and mem_use,
            max_length=tf.reduce_max(self.responses_length))
        self.decoder_output, _, alignments_ta = dynamic_rnn_decoder(
            decoder_cell, decoder_fn_train, response_word_input,
            self.responses_length, scope="decoder_rnn")
        if output_alignments:
            self.alignments = tf.transpose(alignments_ta.stack(), perm=[1, 0, 2, 3])
            self.decoder_loss, self.ppx_loss, self.sentence_ppx = total_loss(
                self.decoder_output, self.responses_target, self.decoder_mask,
                self.alignments, triples_embedding, use_triples, one_hot_triples)
            self.sentence_ppx = tf.identity(self.sentence_ppx, name='ppx_loss')
        else:
            self.decoder_loss = sequence_loss(self.decoder_output,
                                              self.responses_target,
                                              self.decoder_mask)
    ######## ###

    ######## decoder used for inference ###
    with tf.variable_scope('decoder', reuse=True):
        attention_keys, attention_values, attention_score_fn, attention_construct_fn \
            = prepare_attention(encoder_output, 'bahdanau', num_units, reuse=True,
                                imem=(graph_embed, triples_embedding),
                                output_alignments=output_alignments and mem_use)
        decoder_fn_inference = attention_decoder_fn_inference(
            output_fn, memory_hidden_state,
            attention_keys, attention_values,
            attention_score_fn, attention_construct_fn,
            self.embed, GO_ID, EOS_ID, max_length, num_symbols,
            # imem: tuple of the entity word embeddings and the triple embeddings,
            # ([batch_size, triple_num*triple_len, num_embed_units],
            #  [encoder_batch_size, triple_num*triple_len, 3*num_trans_units])
            imem=(entities_word_embedding,
                  tf.reshape(triples_embedding,
                             [encoder_batch_size, -1, 3 * num_trans_units])),
            selector_fn=selector_fn)
        # decoder_distribution: [batch_size, decoder_len, num_symbols]
        # output_ids_ta: TensorArray of decoder_len entries, each [batch_size]
        self.decoder_distribution, _, output_ids_ta = dynamic_rnn_decoder(
            decoder_cell, decoder_fn_inference, scope="decoder_rnn")

        output_len = tf.shape(self.decoder_distribution)[1]  # decoder_len
        output_ids = tf.transpose(output_ids_ta.gather(tf.range(output_len)))  # [batch_size, decoder_len]
        # clip the output ids, because negative values mark entity words
        word_ids = tf.cast(tf.clip_by_value(output_ids, 0, num_symbols), tf.int64)  # [batch_size, decoder_len]
        # compute each entity word's actual position in entities, [batch_size, decoder_len]:
        # 1. tf.shape(entities_word_embedding)[1] = triple_num*triple_len
        # 2. tf.range(encoder_batch_size): [batch_size]
        # 3. tf.reshape(tf.range(encoder_batch_size) * tf.shape(entities_word_embedding)[1], [-1, 1]):
        #    [batch_size, 1], the base address of each row in the flattened entities
        # 4. tf.clip_by_value(-output_ids, 0, num_symbols): [batch_size, decoder_len],
        #    the offset of the entity word within its row
        # 5. entity_ids: [batch_size, decoder_len], the actual position in entities
        entity_ids = tf.reshape(
            tf.clip_by_value(-output_ids, 0, num_symbols) +
            tf.reshape(tf.range(encoder_batch_size) * tf.shape(entities_word_embedding)[1], [-1, 1]),
            [-1])
        # fetch the entity words that were used, [batch_size, decoder_len]:
        # 1. entities: [batch_size, triple_num, triple_len]
        # 2. tf.reshape(self.entities, [-1]): [batch_size*triple_num*triple_len]
        # 3. tf.gather: [batch_size*decoder_len]
        # 4. entities: [batch_size, decoder_len]
        entities = tf.reshape(tf.gather(tf.reshape(self.entities, [-1]), entity_ids),
                              [-1, output_len])
        words = self.index2symbol.lookup(word_ids)  # map ids back to actual words
        # output_ids > 0 is a bool tensor; where it is True, take the word,
        # otherwise take the entity
        self.generation = tf.where(output_ids > 0, words, entities)
        self.generation = tf.identity(self.generation, name='generation')  # [batch_size, decoder_len]
    ######## ###

    # initialize the training process
    self.learning_rate = tf.Variable(float(learning_rate), trainable=False, dtype=tf.float32)
    # the decayed learning rate is never actually applied
    self.learning_rate_decay_op = self.learning_rate.assign(
        self.learning_rate * learning_rate_decay_factor)
    # number of parameter updates so far
    self.global_step = tf.Variable(0, trainable=False)
    # parameters to train
    self.params = tf.global_variables()

    # choose the optimizer
    opt = tf.train.AdamOptimizer(learning_rate=learning_rate)
    self.lr = opt._lr

    # gradients of the parameters w.r.t. decoder_loss
    gradients = tf.gradients(self.decoder_loss, self.params)
    # gradient clipping
    clipped_gradients, self.gradient_norm = tf.clip_by_global_norm(gradients, max_gradient_norm)
    self.update = opt.apply_gradients(zip(clipped_gradients, self.params),
                                      global_step=self.global_step)

    # log the loss
    tf.summary.scalar('decoder_loss', self.decoder_loss)
    for each in tf.trainable_variables():
        tf.summary.histogram(each.name, each)  # log how each variable trains
    self.merged_summary_op = tf.summary.merge_all()

    self.saver = tf.train.Saver(write_version=tf.train.SaverDef.V2, max_to_keep=3,
                                pad_step_number=True, keep_checkpoint_every_n_hours=1.0)
    self.saver_epoch = tf.train.Saver(write_version=tf.train.SaverDef.V2, max_to_keep=1000,
                                      pad_step_number=True)
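# --------------------------------------------------------------------------
# Hedged sketch (added): one hop of the key-value memory read above, for a
# single batch element and a single layer, in plain numpy. `memory_read` and
# the random shapes are illustrative assumptions, not repo code.
# --------------------------------------------------------------------------
import numpy as np

def memory_read(graph_embed, hidden_state, W_a, W_b, W_c, W_r):
    """p = softmax(u . m); o = sum_i p_i c_i; return [o W_r ; hidden_state]."""
    m = graph_embed @ W_a          # input memories   [triple_num, memory_units]
    c = graph_embed @ W_c          # output memories  [triple_num, memory_units]
    u = hidden_state @ W_b         # state vector     [memory_units]
    scores = m @ u                 # one logit per graph
    p = np.exp(scores - scores.max())
    p /= p.sum()                   # softmax over the graphs
    o = p @ c                      # weighted sum of output memories
    return np.concatenate([o @ W_r, hidden_state])

rng = np.random.default_rng(0)
triple_num, graph_dim, num_units, memory_units = 4, 200, 512, 100
out = memory_read(rng.normal(size=(triple_num, graph_dim)),      # static graph vectors
                  rng.normal(size=num_units),                    # one layer's encoder state
                  rng.normal(size=(graph_dim, memory_units)),    # weight_a
                  rng.normal(size=(num_units, memory_units)),    # weight_b
                  rng.normal(size=(graph_dim, memory_units)),    # weight_c
                  rng.normal(size=(memory_units, memory_units))) # weight_r
print(out.shape)  # (612,) == num_units + memory_units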
def __init__(
        self,
        num_symbols,            # vocabulary size
        num_embed_units,        # word-embedding size
        num_units,              # units per RNN layer
        num_layers,             # number of RNN layers
        embed,                  # word embeddings
        entity_embed=None,
        num_entities=0,
        num_trans_units=100,
        learning_rate=0.0001,
        learning_rate_decay_factor=0.95,
        max_gradient_norm=5.0,
        num_samples=500,        # number of samples for sampled softmax
        max_length=60,
        mem_use=True,
        output_alignments=True,
        use_lstm=False):
    self.posts = tf.placeholder(tf.string, (None, None), 'enc_inps')  # batch_size * encoder_len
    self.posts_length = tf.placeholder(tf.int32, (None), 'enc_lens')  # batch_size
    self.responses = tf.placeholder(tf.string, (None, None), 'dec_inps')  # batch_size * decoder_len
    self.responses_length = tf.placeholder(tf.int32, (None), 'dec_lens')  # batch_size
    self.entities = tf.placeholder(tf.string, (None, None, None), 'entities')  # batch_size * triple_num * triple_len
    self.entity_masks = tf.placeholder(tf.string, (None, None), 'entity_masks')  # unused
    self.triples = tf.placeholder(tf.string, (None, None, None, 3), 'triples')  # batch_size * triple_num * triple_len * 3
    self.posts_triple = tf.placeholder(tf.int32, (None, None, 1), 'enc_triples')  # batch_size * encoder_len * 1
    self.responses_triple = tf.placeholder(tf.string, (None, None, 3), 'dec_triples')  # batch_size * decoder_len * 3
    self.match_triples = tf.placeholder(tf.int32, (None, None, None), 'match_triples')  # batch_size * decoder_len * triple_num

    # encoder_batch_size and encoder_len
    encoder_batch_size, encoder_len = tf.unstack(tf.shape(self.posts))
    # triple_num: number of knowledge graphs per post (padded)
    triple_num = tf.shape(self.triples)[1]
    # triple_len: number of related entities per graph (padded)
    triple_len = tf.shape(self.triples)[2]

    # which knowledge triples are used
    one_hot_triples = tf.one_hot(self.match_triples, triple_len)  # batch_size * decoder_len * triple_num * triple_len
    # marks with 1 the decoding steps that used a knowledge triple
    use_triples = tf.reduce_sum(one_hot_triples, axis=[2, 3])  # batch_size * decoder_len

    # hash table mapping words to indices
    self.symbol2index = MutableHashTable(
        key_dtype=tf.string,        # dtype of the key tensors
        value_dtype=tf.int64,       # dtype of the value tensors
        default_value=UNK_ID,       # default value for missing keys
        shared_name="in_table",     # if non-empty, the table is shared under this name across sessions
        name="in_table",            # op name
        checkpoint=True)            # if True, the table contents are saved to and restored from checkpoints
    # hash table mapping indices to words
    self.index2symbol = MutableHashTable(key_dtype=tf.int64,
                                         value_dtype=tf.string,
                                         default_value='_UNK',
                                         shared_name="out_table",
                                         name="out_table",
                                         checkpoint=True)
    # hash table mapping entities to indices
    self.entity2index = MutableHashTable(key_dtype=tf.string,
                                         value_dtype=tf.int64,
                                         default_value=NONE_ID,
                                         shared_name="entity_in_table",
                                         name="entity_in_table",
                                         checkpoint=True)
    # hash table mapping indices to entities
    self.index2entity = MutableHashTable(key_dtype=tf.int64,
                                         value_dtype=tf.string,
                                         default_value='_NONE',
                                         shared_name="entity_out_table",
                                         name="entity_out_table",
                                         checkpoint=True)

    # map the post strings to word ids
    self.posts_word_id = self.symbol2index.lookup(self.posts)  # batch_size * encoder_len
    # map the post strings to entity ids
    self.posts_entity_id = self.entity2index.lookup(self.posts)  # batch_size * encoder_len
    # map the response strings to word ids
    self.responses_target = self.symbol2index.lookup(self.responses)  # batch_size * decoder_len

    # decoder batch_size and decoder_len
    batch_size, decoder_len = tf.shape(self.responses)[0], tf.shape(self.responses)[1]
    # drop the last column of responses_target and prepend GO_ID
    self.responses_word_id = tf.concat([
        tf.ones([batch_size, 1], dtype=tf.int64) * GO_ID,
        tf.split(self.responses_target, [decoder_len - 1, 1], 1)[0]
    ], 1)  # batch_size * decoder_len

    # response mask:
    # first one-hot encode the response lengths, then accumulate from right to
    # left, which gives a matrix that is 1 inside the length and 0 outside;
    # the final reshape is probably unnecessary
    self.decoder_mask = tf.reshape(
        tf.cumsum(tf.one_hot(self.responses_length - 1, decoder_len),
                  reverse=True, axis=1), [-1, decoder_len])  # batch_size * decoder_len

    # initialize the word and entity embeddings: use the given values if
    # provided, otherwise initialize randomly
    if embed is None:
        self.embed = tf.get_variable('word_embed', [num_symbols, num_embed_units], tf.float32)
    else:
        self.embed = tf.get_variable('word_embed', dtype=tf.float32, initializer=embed)
    if entity_embed is None:
        self.entity_trans = tf.get_variable('entity_embed', [num_entities, num_trans_units],
                                            tf.float32, trainable=False)
    else:
        self.entity_trans = tf.get_variable('entity_embed', dtype=tf.float32,
                                            initializer=entity_embed, trainable=False)

    # a fully connected layer over the entity embeddings with
    # size=num_trans_units and tanh activation
    # (why pass them through another dense layer?)
    self.entity_trans_transformed = tf.layers.dense(self.entity_trans, num_trans_units,
                                                    activation=tf.tanh, name='trans_transformation')
    # a 7 * num_trans_units zero-initialized matrix
    padding_entity = tf.get_variable('entity_padding_embed', [7, num_trans_units],
                                     dtype=tf.float32, initializer=tf.zeros_initializer())
    # prepend padding_entity to entity_trans_transformed (what is the padding for?)
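    # NOTE (added): the 7 zero rows answer the question above; they are the
    # embeddings of the special markers ['_NONE', '_PAD_H', '_PAD_R', '_PAD_T',
    # '_NAF_H', '_NAF_R', '_NAF_T'] that pad graphs and triples to fixed shape,
    # and prepending them presumably reserves entity ids 0..6 for those markers.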
    self.entity_embed = tf.concat([padding_entity, self.entity_trans_transformed], axis=0)

    # tf.nn.embedding_lookup adds a dimension, which the reshape removes again
    triples_embedding = tf.reshape(
        tf.nn.embedding_lookup(self.entity_embed, self.entity2index.lookup(self.triples)),
        [encoder_batch_size, triple_num, -1, 3 * num_trans_units])
    entities_word_embedding = tf.reshape(
        tf.nn.embedding_lookup(self.embed, self.symbol2index.lookup(self.entities)),
        [encoder_batch_size, -1, num_embed_units])  # [batch_size, triple_num*triple_len, num_embed_units]

    # split into head, relation, and tail
    head, relation, tail = tf.split(triples_embedding, [num_trans_units] * 3, axis=3)

    # static graph attention
    with tf.variable_scope('graph_attention'):
        # concatenate head and tail
        head_tail = tf.concat([head, tail], axis=3)  # batch_size * triple_num * triple_len * 200
        # tanh(dot(W, head_tail))
        head_tail_transformed = tf.layers.dense(
            head_tail, num_trans_units, activation=tf.tanh,
            name='head_tail_transform')  # batch_size * triple_num * triple_len * 100
        # dot(W, relation)
        relation_transformed = tf.layers.dense(
            relation, num_trans_units,
            name='relation_transform')  # batch_size * triple_num * triple_len * 100
        # elementwise product then sum == inner product of the two vectors:
        # dot(transpose(dot(W, relation)), tanh(dot(W, head_tail)))
        e_weight = tf.reduce_sum(relation_transformed * head_tail_transformed,
                                 axis=3)  # batch_size * triple_num * triple_len
        # alpha weight of each triple in the graph
        alpha_weight = tf.nn.softmax(e_weight)  # batch_size * triple_num * triple_len
        # tf.expand_dims gives alpha_weight an extra axis: batch_size * triple_num * triple_len * 1;
        # summing over axis 2 produces a 100-dim vector per graph
        graph_embed = tf.reduce_sum(tf.expand_dims(alpha_weight, 3) * head_tail,
                                    axis=2)  # batch_size * triple_num * 100

    """
    [0, 1, ..., encoder_batch_size - 1] is reshaped into an
    encoder_batch_size * 1 * 1 tensor: [[[0]], [[1]], [[2]], ...]
    tf.tile expands axis 1 to encoder_batch_size * encoder_len * 1:
    [[[0], [0], ...]], ...]
    concatenating with posts_triple on axis 2 forms the indices tensor:
    [
     [[0 0], [0 0], [0 0], [0 0], [0 1], [0 0], [0 2], [0 0], ... encoder_len],
     [[1 0], [1 0], [1 0], [1 0], [1 1], [1 0], [1 2], [1 0], ... encoder_len],
     [[2 0], [2 0], [2 0], [2 0], [2 1], [2 0], [2 2], [2 0], ... encoder_len],
     ... batch_size
    ]
    tf.gather_nd looks the graph vectors up in graph_embed with these indices
    and fills them back in: encoder_batch_size * encoder_len * 100
    """
    graph_embed_input = tf.gather_nd(
        graph_embed,
        tf.concat([
            tf.tile(tf.reshape(tf.range(encoder_batch_size, dtype=tf.int32), [-1, 1, 1]),
                    [1, encoder_len, 1]),
            self.posts_triple
        ], axis=2))

    # turn responses_triple into entity embeddings, batch_size * decoder_len * 300
    triple_embed_input = tf.reshape(
        tf.nn.embedding_lookup(self.entity_embed,
                               self.entity2index.lookup(self.responses_triple)),
        [batch_size, decoder_len, 3 * num_trans_units])

    # turn posts_word_id into word embeddings
    post_word_input = tf.nn.embedding_lookup(self.embed, self.posts_word_id)  # batch_size * encoder_len * 300
    # turn responses_word_id into word embeddings
    response_word_input = tf.nn.embedding_lookup(self.embed, self.responses_word_id)  # batch_size * decoder_len * 300

    # concatenate post_word_input and graph_embed_input on the last axis
    self.encoder_input = tf.concat([post_word_input, graph_embed_input],
                                   axis=2)  # batch_size * encoder_len * 400
    # concatenate response_word_input and triple_embed_input on the last axis
    self.decoder_input = tf.concat([response_word_input, triple_embed_input],
                                   axis=2)  # batch_size * decoder_len * 600

    # build the deep RNNs
    encoder_cell = MultiRNNCell([GRUCell(num_units) for _ in range(num_layers)])
    decoder_cell = MultiRNNCell([GRUCell(num_units) for _ in range(num_layers)])

    # rnn encoder
    encoder_output, encoder_state = dynamic_rnn(encoder_cell, self.encoder_input,
                                                self.posts_length, dtype=tf.float32,
                                                scope="encoder")

    # the vocabulary is too large for the decoder to output directly, so the
    # projection maps the low-dimensional output up to the vocabulary;
    # returns: output function, selector function, sequence loss,
    # sampled-softmax sequence loss, and total loss
    output_fn, selector_fn, sequence_loss, sampled_sequence_loss, total_loss = \
        output_projection_layer(num_units, num_symbols, num_samples)

    # decoder used for training
    with tf.variable_scope('decoder'):
        # prepare the attention:
        # attention_keys_init: the attention keys
        # attention_values_init: the attention values
        # attention_score_fn_init: computes the attention context
        # attention_construct_fn_init: concatenates all contexts
        attention_keys_init, attention_values_init, attention_score_fn_init, attention_construct_fn_init \
            = prepare_attention(encoder_output, 'bahdanau', num_units,
                                imem=(graph_embed, triples_embedding),
                                output_alignments=output_alignments and mem_use)  # 'luong', num_units
        # returns the function applied at every training step to that step's
        # output and the next step's input
        decoder_fn_train = attention_decoder_fn_train(
            encoder_state,
            attention_keys_init, attention_values_init,
            attention_score_fn_init, attention_construct_fn_init,
            output_alignments=output_alignments and mem_use,
            max_length=tf.reduce_max(self.responses_length))
        # outputs, final state, and the TensorArray of alignments
        self.decoder_output, _, alignments_ta = dynamic_rnn_decoder(
            decoder_cell, decoder_fn_train, self.decoder_input,
            self.responses_length, scope="decoder_rnn")
        if output_alignments:
            # stack and transpose the alignments before the loss
            self.alignments = tf.transpose(alignments_ta.stack(), perm=[1, 0, 2, 3])
            self.decoder_loss, self.ppx_loss, self.sentence_ppx = total_loss(
                self.decoder_output, self.responses_target, self.decoder_mask,
                self.alignments, triples_embedding, use_triples, one_hot_triples)
            # wrap sentence_ppx into a single named op
            self.sentence_ppx = tf.identity(self.sentence_ppx, name='ppx_loss')
        else:
            self.decoder_loss = sequence_loss(self.decoder_output,
                                              self.responses_target,
                                              self.decoder_mask)
    # decoder used for inference
    with tf.variable_scope('decoder', reuse=True):
        # prepare the attention functions
        attention_keys, attention_values, attention_score_fn, attention_construct_fn \
            = prepare_attention(encoder_output, 'bahdanau', num_units, reuse=True,
                                imem=(graph_embed, triples_embedding),
                                output_alignments=output_alignments and mem_use)  # 'luong', num_units
        # imem: tuple of the entity word embeddings and the triple embeddings,
        # ([batch_size, triple_num*triple_len, num_embed_units],
        #  [encoder_batch_size, triple_num*triple_len, 3*num_trans_units])
        decoder_fn_inference = attention_decoder_fn_inference(
            output_fn, encoder_state,
            attention_keys, attention_values,
            attention_score_fn, attention_construct_fn,
            self.embed, GO_ID, EOS_ID, max_length, num_symbols,
            imem=(entities_word_embedding,
                  tf.reshape(triples_embedding,
                             [encoder_batch_size, -1, 3 * num_trans_units])),
            selector_fn=selector_fn)
        self.decoder_distribution, _, output_ids_ta = dynamic_rnn_decoder(
            decoder_cell, decoder_fn_inference, scope="decoder_rnn")

        output_len = tf.shape(self.decoder_distribution)[1]  # decoder_len
        output_ids = tf.transpose(output_ids_ta.gather(tf.range(output_len)))  # [batch_size, decoder_len]
        # clip the value range of output_ids
        word_ids = tf.cast(tf.clip_by_value(output_ids, 0, num_symbols), tf.int64)  # [batch_size, decoder_len]
        # compute the position in entities of each entity word that was used:
        # 1. tf.shape(entities_word_embedding)[1] = triple_num*triple_len
        # 2. tf.range(encoder_batch_size): [batch_size]
        # 3. tf.reshape(tf.range(encoder_batch_size) * tf.shape(entities_word_embedding)[1], [-1, 1]):
        #    [batch_size, 1], each row's base offset into the flattened entities
        # 4. tf.clip_by_value(-output_ids, 0, num_symbols): [batch_size, decoder_len],
        #    the relative position of the entity word
        # 5. entity_ids: [batch_size * decoder_len], the actual position in
        #    entities after adding the base offset
        entity_ids = tf.reshape(
            tf.clip_by_value(-output_ids, 0, num_symbols) +
            tf.reshape(tf.range(encoder_batch_size) * tf.shape(entities_word_embedding)[1], [-1, 1]),
            [-1])
        # fetch the entity words that were used:
        # 1. entities: [batch_size, triple_num, triple_len]
        # 2. tf.reshape(self.entities, [-1]): [batch_size * triple_num * triple_len]
        # 3. tf.gather: [batch_size*decoder_len]
        # 4. entities: [batch_size, output_len]
        entities = tf.reshape(tf.gather(tf.reshape(self.entities, [-1]), entity_ids),
                              [-1, output_len])
        words = self.index2symbol.lookup(word_ids)  # map ids back to actual words
        # output_ids > 0 is a bool tensor; where it is True, take the word,
        # otherwise take the entity
        self.generation = tf.where(output_ids > 0, words, entities)
        self.generation = tf.identity(self.generation, name='generation')

    # initialize the training process
    self.learning_rate = tf.Variable(float(learning_rate), trainable=False, dtype=tf.float32)
    # the decayed learning rate is never actually applied
    self.learning_rate_decay_op = self.learning_rate.assign(
        self.learning_rate * learning_rate_decay_factor)
    # number of parameter updates so far
    self.global_step = tf.Variable(0, trainable=False)
    # parameters to train
    self.params = tf.global_variables()

    # choose the optimizer
    opt = tf.train.AdamOptimizer(learning_rate=learning_rate)
    self.lr = opt._lr

    # gradients of the parameters w.r.t. decoder_loss
    gradients = tf.gradients(self.decoder_loss, self.params)
    # gradient clipping
    clipped_gradients, self.gradient_norm = tf.clip_by_global_norm(gradients, max_gradient_norm)
    self.update = opt.apply_gradients(zip(clipped_gradients, self.params),
                                      global_step=self.global_step)

    tf.summary.scalar('decoder_loss', self.decoder_loss)
    for each in tf.trainable_variables():
        tf.summary.histogram(each.name, each)
    self.merged_summary_op = tf.summary.merge_all()

    self.saver = tf.train.Saver(write_version=tf.train.SaverDef.V2, max_to_keep=3,
                                pad_step_number=True, keep_checkpoint_every_n_hours=1.0)
    self.saver_epoch = tf.train.Saver(write_version=tf.train.SaverDef.V2, max_to_keep=1000,
                                      pad_step_number=True)
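# --------------------------------------------------------------------------
# Hedged sketch (added): the static graph attention above for one graph, in
# plain numpy. `static_graph_attention` and the toy shapes are illustrative
# assumptions, not repo code.
# --------------------------------------------------------------------------
import numpy as np

def static_graph_attention(head, relation, tail, W_ht, W_r):
    """alpha = softmax(dot(W_r r, tanh(W_ht [h;t]))); g = sum_i alpha_i [h_i;t_i]."""
    head_tail = np.concatenate([head, tail], axis=-1)  # [triple_len, 2*tu]
    ht = np.tanh(head_tail @ W_ht)                     # [triple_len, tu]
    r = relation @ W_r                                 # [triple_len, tu]
    e = (r * ht).sum(axis=-1)                          # inner product per triple
    alpha = np.exp(e - e.max())
    alpha /= alpha.sum()                               # softmax over the triples
    return alpha @ head_tail                           # [2*tu] static graph vector

rng = np.random.default_rng(0)
triple_len, tu = 5, 100                                # tu == num_trans_units
g = static_graph_attention(rng.normal(size=(triple_len, tu)),  # heads
                           rng.normal(size=(triple_len, tu)),  # relations
                           rng.normal(size=(triple_len, tu)),  # tails
                           rng.normal(size=(2 * tu, tu)),      # head_tail_transform
                           rng.normal(size=(tu, tu)))          # relation_transform
print(g.shape)  # (200,) == 2 * num_trans_units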
def __init__(self,
             num_symbols,
             num_embed_units,
             num_units,
             num_layers,
             embed,
             entity_embed=None,
             num_entities=0,
             num_trans_units=100,
             learning_rate=0.0001,
             learning_rate_decay_factor=0.95,
             max_gradient_norm=5.0,
             num_samples=500,
             max_length=60,
             mem_use=True,
             output_alignments=True,
             use_lstm=False):
    # placeholders for the input data
    self.posts = tf.placeholder(tf.string, (None, None), 'enc_inps')  # batch * encoder_len
    self.posts_length = tf.placeholder(tf.int32, (None), 'enc_lens')  # batch
    self.responses = tf.placeholder(tf.string, (None, None), 'dec_inps')  # batch * decoder_len
    self.responses_length = tf.placeholder(tf.int32, (None), 'dec_lens')  # batch
    self.entities = tf.placeholder(tf.string, (None, None, None), 'entities')  # batch * triple_num * triple_len
    self.entity_masks = tf.placeholder(tf.string, (None, None), 'entity_masks')  # batch
    self.triples = tf.placeholder(tf.string, (None, None, None, 3), 'triples')  # batch * triple_num * triple_len * 3
    self.posts_triple = tf.placeholder(tf.int32, (None, None, 1), 'enc_triples')  # batch * encoder_len * 1
    self.responses_triple = tf.placeholder(tf.string, (None, None, 3), 'dec_triples')  # batch * decoder_len * 3
    self.match_triples = tf.placeholder(tf.int32, (None, None, None), 'match_triples')  # batch * decoder_len * triple_num

    encoder_batch_size, encoder_len = tf.unstack(tf.shape(self.posts))
    triple_num = tf.shape(self.triples)[1]
    triple_len = tf.shape(self.triples)[2]
    one_hot_triples = tf.one_hot(self.match_triples, triple_len)
    use_triples = tf.reduce_sum(one_hot_triples, axis=[2, 3])

    # build the vocabulary lookup tables (index to string, string to index)
    self.symbol2index = MutableHashTable(key_dtype=tf.string,
                                         value_dtype=tf.int64,
                                         default_value=UNK_ID,
                                         shared_name="in_table",
                                         name="in_table",
                                         checkpoint=True)
    self.index2symbol = MutableHashTable(key_dtype=tf.int64,
                                         value_dtype=tf.string,
                                         default_value='_UNK',
                                         shared_name="out_table",
                                         name="out_table",
                                         checkpoint=True)
    self.entity2index = MutableHashTable(key_dtype=tf.string,
                                         value_dtype=tf.int64,
                                         default_value=NONE_ID,
                                         shared_name="entity_in_table",
                                         name="entity_in_table",
                                         checkpoint=True)
    self.index2entity = MutableHashTable(key_dtype=tf.int64,
                                         value_dtype=tf.string,
                                         default_value='_NONE',
                                         shared_name="entity_out_table",
                                         name="entity_out_table",
                                         checkpoint=True)

    self.posts_word_id = self.symbol2index.lookup(self.posts)  # batch * encoder_len
    self.posts_entity_id = self.entity2index.lookup(self.posts)  # batch * encoder_len
    #self.posts_word_id = tf.Print(self.posts_word_id, ['use_triples', use_triples, 'one_hot_triples', one_hot_triples], summarize=1e6)
    self.responses_target = self.symbol2index.lookup(self.responses)  # batch * decoder_len

    batch_size, decoder_len = tf.shape(self.responses)[0], tf.shape(self.responses)[1]
    self.responses_word_id = tf.concat([
        tf.ones([batch_size, 1], dtype=tf.int64) * GO_ID,
        tf.split(self.responses_target, [decoder_len - 1, 1], 1)[0]
    ], 1)  # batch * decoder_len
    self.decoder_mask = tf.reshape(
        tf.cumsum(tf.one_hot(self.responses_length - 1, decoder_len),
                  reverse=True, axis=1), [-1, decoder_len])

    # build the embedding tables (index to vector)
    if embed is None:
        # initialize the word embeddings randomly
        self.embed = tf.get_variable('word_embed', [num_symbols, num_embed_units], tf.float32)
    else:
        # initialize with pre-trained word vectors (e.g. GloVe or Word2Vec)
        self.embed = tf.get_variable('word_embed', dtype=tf.float32, initializer=embed)
    if entity_embed is None:
        # initialize the entity embeddings randomly
        self.entity_trans = tf.get_variable('entity_embed', [num_entities, num_trans_units],
                                            tf.float32, trainable=False)
    else:
        # initialize with pre-trained entity embeddings
        self.entity_trans = tf.get_variable('entity_embed', dtype=tf.float32,
                                            initializer=entity_embed, trainable=False)
    self.entity_trans_transformed = tf.layers.dense(self.entity_trans, num_trans_units,
                                                    activation=tf.tanh, name='trans_transformation')
    padding_entity = tf.get_variable('entity_padding_embed', [7, num_trans_units],
                                     dtype=tf.float32, initializer=tf.zeros_initializer())
    self.entity_embed = tf.concat([padding_entity, self.entity_trans_transformed], axis=0)

    triples_embedding = tf.reshape(
        tf.nn.embedding_lookup(self.entity_embed, self.entity2index.lookup(self.triples)),
        [encoder_batch_size, triple_num, -1, 3 * num_trans_units])
    entities_word_embedding = tf.reshape(
        tf.nn.embedding_lookup(self.embed, self.symbol2index.lookup(self.entities)),
        [encoder_batch_size, -1, num_embed_units])
    head, relation, tail = tf.split(triples_embedding, [num_trans_units] * 3, axis=3)

    # static attention of the knowledge-fusion layer
    with tf.variable_scope('graph_attention'):
        # concatenate head and tail into one vector
        head_tail = tf.concat([head, tail], axis=3)
        head_tail_transformed = tf.layers.dense(head_tail, num_trans_units,
                                                activation=tf.tanh, name='head_tail_transform')
        # relation vector
        relation_transformed = tf.layers.dense(relation, num_trans_units,
                                               name='relation_transform')
        # attention logits from relation and head_tail
        e_weight = tf.reduce_sum(relation_transformed * head_tail_transformed, axis=3)
        # normalize the attention weights
        alpha_weight = tf.nn.softmax(e_weight)
        # weighted sum of head_tail with the attention weights
        graph_embed = tf.reduce_sum(tf.expand_dims(alpha_weight, 3) * head_tail, axis=2)

    graph_embed_input = tf.gather_nd(
        graph_embed,
        tf.concat([
            tf.tile(tf.reshape(tf.range(encoder_batch_size, dtype=tf.int32), [-1, 1, 1]),
                    [1, encoder_len, 1]),
            self.posts_triple
        ], axis=2))
    triple_embed_input = tf.reshape(
        tf.nn.embedding_lookup(self.entity_embed,
                               self.entity2index.lookup(self.responses_triple)),
        [batch_size, decoder_len, 3 * num_trans_units])

    post_word_input = tf.nn.embedding_lookup(self.embed, self.posts_word_id)  # batch * len * unit
    response_word_input = tf.nn.embedding_lookup(self.embed, self.responses_word_id)  # batch * len * unit

    # concatenate the attention-weighted graph information onto the input sentence
    self.encoder_input = tf.concat([post_word_input, graph_embed_input], axis=2)
    # concatenate the full triple information onto the output sentence
    self.decoder_input = tf.concat([response_word_input, triple_embed_input], axis=2)

    # the encoder uses GRUCells with num_layers layers
    encoder_cell = MultiRNNCell([GRUCell(num_units) for _ in range(num_layers)])
    # the decoder uses GRUCells with num_layers layers
    decoder_cell = MultiRNNCell([GRUCell(num_units) for _ in range(num_layers)])

    # rnn encoder wrapper
    encoder_output, encoder_state = dynamic_rnn(encoder_cell, self.encoder_input,
                                                self.posts_length, dtype=tf.float32,
                                                scope="encoder")

    # get output projection function
    output_fn, selector_fn, sequence_loss, sampled_sequence_loss, total_loss = \
        output_projection_layer(num_units, num_symbols, num_samples)

    # decoder used for training
    with tf.variable_scope('decoder'):
        # get the attention functions
        attention_keys_init, attention_values_init, attention_score_fn_init, attention_construct_fn_init \
            = prepare_attention(encoder_output, 'bahdanau', num_units,
                                imem=(graph_embed, triples_embedding),
                                output_alignments=output_alignments and mem_use)  # 'luong', num_units
        decoder_fn_train = attention_decoder_fn_train(
            encoder_state,
            attention_keys_init, attention_values_init,
            attention_score_fn_init, attention_construct_fn_init,
            output_alignments=output_alignments and mem_use,
            max_length=tf.reduce_max(self.responses_length))
        self.decoder_output, _, alignments_ta = dynamic_rnn_decoder(
            decoder_cell, decoder_fn_train, self.decoder_input,
            self.responses_length, scope="decoder_rnn")
        if output_alignments:
            self.alignments = tf.transpose(alignments_ta.stack(), perm=[1, 0, 2, 3])
            self.decoder_loss, self.ppx_loss, self.sentence_ppx = total_loss(
                self.decoder_output, self.responses_target, self.decoder_mask,
                self.alignments, triples_embedding, use_triples, one_hot_triples)
            self.sentence_ppx = tf.identity(self.sentence_ppx, name='ppx_loss')
        else:
            self.decoder_loss = sequence_loss(self.decoder_output,
                                              self.responses_target,
                                              self.decoder_mask)

    # decoder used for inference
    with tf.variable_scope('decoder', reuse=True):
        # get the attention functions
        attention_keys, attention_values, attention_score_fn, attention_construct_fn \
            = prepare_attention(encoder_output, 'bahdanau', num_units, reuse=True,
                                imem=(graph_embed, triples_embedding),
                                output_alignments=output_alignments and mem_use)  # 'luong', num_units
        decoder_fn_inference = attention_decoder_fn_inference(
            output_fn, encoder_state,
            attention_keys, attention_values,
            attention_score_fn, attention_construct_fn,
            self.embed, GO_ID, EOS_ID, max_length, num_symbols,
            imem=(entities_word_embedding,
                  tf.reshape(triples_embedding,
                             [encoder_batch_size, -1, 3 * num_trans_units])),
            selector_fn=selector_fn)
        self.decoder_distribution, _, output_ids_ta = dynamic_rnn_decoder(
            decoder_cell, decoder_fn_inference, scope="decoder_rnn")

        output_len = tf.shape(self.decoder_distribution)[1]
        output_ids = tf.transpose(output_ids_ta.gather(tf.range(output_len)))
        word_ids = tf.cast(tf.clip_by_value(output_ids, 0, num_symbols), tf.int64)
        entity_ids = tf.reshape(
            tf.clip_by_value(-output_ids, 0, num_symbols) +
            tf.reshape(tf.range(encoder_batch_size) * tf.shape(entities_word_embedding)[1], [-1, 1]),
            [-1])
        entities = tf.reshape(tf.gather(tf.reshape(self.entities, [-1]), entity_ids),
                              [-1, output_len])
        words = self.index2symbol.lookup(word_ids)
        # assemble the response sentence used for output
        self.generation = tf.where(output_ids > 0, words, entities)
        self.generation = tf.identity(self.generation, name='generation')

    # initialize the training parameters
    self.learning_rate = tf.Variable(float(learning_rate), trainable=False, dtype=tf.float32)
    self.learning_rate_decay_op = self.learning_rate.assign(
        self.learning_rate * learning_rate_decay_factor)
    self.global_step = tf.Variable(0, trainable=False)
    self.params = tf.global_variables()

    # use the Adam optimizer: efficient, smooth gradients, easy to tune
    opt = tf.train.AdamOptimizer(learning_rate=learning_rate)
    self.lr = opt._lr
    gradients = tf.gradients(self.decoder_loss, self.params)
    clipped_gradients, self.gradient_norm = tf.clip_by_global_norm(gradients, max_gradient_norm)
    self.update = opt.apply_gradients(zip(clipped_gradients, self.params),
                                      global_step=self.global_step)

    tf.summary.scalar('decoder_loss', self.decoder_loss)
    for each in tf.trainable_variables():
        tf.summary.histogram(each.name, each)
    self.merged_summary_op = tf.summary.merge_all()

    self.saver = tf.train.Saver(write_version=tf.train.SaverDef.V2, max_to_keep=3,
                                pad_step_number=True, keep_checkpoint_every_n_hours=1.0)
    self.saver_epoch = tf.train.Saver(write_version=tf.train.SaverDef.V2, max_to_keep=1000,
                                      pad_step_number=True)
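# --------------------------------------------------------------------------
# Hedged sketch (added): how the signed output ids above are decoded into a
# sentence, in plain numpy. Non-negative ids index the word vocabulary;
# negative ids index the flattened per-batch entity list via -id plus a row
# base offset. `ids_to_tokens` and the toy data are illustrative assumptions.
# --------------------------------------------------------------------------
import numpy as np

def ids_to_tokens(output_ids, vocab, entities):
    """vocab: [num_symbols] strings; entities: [batch, triple_num*triple_len]."""
    batch, out_len = output_ids.shape
    words = vocab[np.clip(output_ids, 0, len(vocab) - 1)]
    base = (np.arange(batch) * entities.shape[1])[:, None]  # row base address
    ent_ids = (np.clip(-output_ids, 0, None) + base).reshape(-1)
    ents = entities.reshape(-1)[ent_ids].reshape(batch, out_len)
    # where the id is positive take the vocab word, otherwise the entity
    return np.where(output_ids > 0, words, ents)

vocab = np.array(['_PAD', 'hi', 'there'])
entities = np.array([['apple', 'pie'], ['rust', 'lang']])
print(ids_to_tokens(np.array([[1, -1], [2, 0]]), vocab, entities))
# [['hi' 'pie'], ['there' 'rust']]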