def __init__(self, num_symbols, num_embed_units, num_units, num_layers, embed,
             entity_embed=None, num_entities=0, num_trans_units=100,
             learning_rate=0.0001, learning_rate_decay_factor=0.95,
             max_gradient_norm=5.0, num_samples=512, max_length=60,
             output_alignments=True, use_lstm=False):
    """Build the TF1 graph for a knowledge-triple-aware seq2seq model.

    Creates string placeholders for posts/responses and knowledge triples,
    word/entity hash tables, word embeddings plus frozen entity embeddings
    (with a trainable tanh projection), a multi-layer GRU encoder, an
    attentive decoder (separate train and inference branches sharing the
    'decoder' variable scope), the losses, and an Adam training op.

    NOTE(review): `use_lstm` is accepted but unused here (GRU is
    hard-coded); `max_length` only bounds the inference decoder.
    """
    # --- inputs: raw string tokens; *_length are per-example token counts.
    self.posts = tf.placeholder(tf.string, (None, None), 'enc_inps')  # batch*len
    self.posts_length = tf.placeholder(tf.int32, (None), 'enc_lens')  # batch
    self.responses = tf.placeholder(tf.string, (None, None), 'dec_inps')  # batch*len
    self.responses_length = tf.placeholder(tf.int32, (None), 'dec_lens')  # batch
    self.entities = tf.placeholder(tf.string, (None, None), 'entities')  # batch
    self.entity_masks = tf.placeholder(tf.string, (None, None), 'entity_masks')  # batch
    self.triples = tf.placeholder(tf.string, (None, None, 3), 'triples')  # batch
    self.posts_triple = tf.placeholder(tf.int32, (None, None, 1), 'enc_triples')  # batch
    self.responses_triple = tf.placeholder(tf.string, (None, None, 3), 'dec_triples')  # batch
    self.match_triples = tf.placeholder(tf.int32, (None, None), 'match_triples')  # batch

    encoder_batch_size, encoder_len = tf.unstack(tf.shape(self.posts))
    triple_num = tf.shape(self.triples)[1]

    #use_triples = tf.reduce_sum(tf.cast(tf.greater_equal(self.match_triples, 0), tf.float32), axis=-1)
    # One-hot over the triple axis marks which triple (if any) each decoder
    # step matched; summing that axis yields a per-step "used a triple"
    # indicator (unmatched steps have match id < 0, so one_hot is all-zero).
    one_hot_triples = tf.one_hot(self.match_triples, triple_num)
    use_triples = tf.reduce_sum(one_hot_triples, axis=[2])

    # word <-> id and entity <-> id lookup tables (saved in checkpoints).
    self.symbol2index = MutableHashTable(
        key_dtype=tf.string,
        value_dtype=tf.int64,
        default_value=UNK_ID,
        shared_name="in_table",
        name="in_table",
        checkpoint=True)
    self.index2symbol = MutableHashTable(
        key_dtype=tf.int64,
        value_dtype=tf.string,
        default_value='_UNK',
        shared_name="out_table",
        name="out_table",
        checkpoint=True)
    self.entity2index = MutableHashTable(
        key_dtype=tf.string,
        value_dtype=tf.int64,
        default_value=NONE_ID,
        shared_name="entity_in_table",
        name="entity_in_table",
        checkpoint=True)
    self.index2entity = MutableHashTable(
        key_dtype=tf.int64,
        value_dtype=tf.string,
        default_value='_NONE',
        shared_name="entity_out_table",
        name="entity_out_table",
        checkpoint=True)

    # build the vocab table (string to index)
    self.posts_word_id = self.symbol2index.lookup(self.posts)  # batch*len
    self.posts_entity_id = self.entity2index.lookup(self.posts)  # batch*len
    #self.posts_word_id = tf.Print(self.posts_word_id, ['use_triples', use_triples, 'one_hot_triples', one_hot_triples], summarize=1e6)
    self.responses_target = self.symbol2index.lookup(self.responses)  # batch*len

    batch_size, decoder_len = tf.shape(self.responses)[0], tf.shape(self.responses)[1]
    # Drop the last target step and prepend GO: teacher-forcing decoder input.
    self.responses_word_id = tf.concat(
        [tf.ones([batch_size, 1], dtype=tf.int64) * GO_ID,
         tf.split(self.responses_target, [decoder_len - 1, 1], 1)[0]], 1)  # batch*len
    # 1.0 for positions inside the true response length, 0.0 on padding.
    self.decoder_mask = tf.reshape(
        tf.cumsum(tf.one_hot(self.responses_length - 1, decoder_len),
                  reverse=True, axis=1), [-1, decoder_len])

    # build the embedding table (index to vector)
    if embed is None:
        # initialize the embedding randomly
        self.embed = tf.get_variable('word_embed', [num_symbols, num_embed_units], tf.float32)
    else:
        # initialize the embedding by pre-trained word vectors
        self.embed = tf.get_variable('word_embed', dtype=tf.float32, initializer=embed)
    if entity_embed is None:
        # initialize the entity embedding randomly; kept frozen (trainable=False)
        self.entity_trans = tf.get_variable('entity_embed', [num_entities, num_trans_units],
                                            tf.float32, trainable=False)
    else:
        # initialize the entity embedding by pre-trained vectors; kept frozen
        self.entity_trans = tf.get_variable('entity_embed', dtype=tf.float32,
                                            initializer=entity_embed, trainable=False)

    # Trainable tanh projection on top of the frozen entity embeddings.
    self.entity_trans_transformed = tf.layers.dense(self.entity_trans, num_trans_units,
                                                    activation=tf.tanh,
                                                    name='trans_transformation')
    # 7 zero rows reserved for special/padding entity ids.
    # NOTE(review): the magic 7 must match the number of reserved entity
    # symbols in the entity vocabulary -- confirm against the data pipeline.
    padding_entity = tf.get_variable('entity_padding_embed', [7, num_trans_units],
                                     dtype=tf.float32, initializer=tf.zeros_initializer())
    self.entity_embed = tf.concat([padding_entity, self.entity_trans_transformed], axis=0)

    # Per-triple (head, relation, tail) embeddings concatenated on last axis.
    triples_embedding = tf.reshape(
        tf.nn.embedding_lookup(self.entity_embed, self.entity2index.lookup(self.triples)),
        [encoder_batch_size, triple_num, 3 * num_trans_units])
    # Word embeddings of the entity surface strings (used as copy memory
    # by the inference decoder below).
    entities_word_embedding = tf.reshape(
        tf.nn.embedding_lookup(self.embed, self.symbol2index.lookup(self.entities)),
        [encoder_batch_size, -1, num_embed_units])

    self.encoder_input = tf.nn.embedding_lookup(self.embed, self.posts_word_id)  # batch*len*unit
    self.decoder_input = tf.nn.embedding_lookup(self.embed, self.responses_word_id)  # batch*len*unit

    # One distinct GRU cell per layer.
    encoder_cell = MultiRNNCell([GRUCell(num_units) for _ in range(num_layers)])
    decoder_cell = MultiRNNCell([GRUCell(num_units) for _ in range(num_layers)])

    # rnn encoder
    encoder_output, encoder_state = dynamic_rnn(encoder_cell, self.encoder_input,
                                                self.posts_length, dtype=tf.float32,
                                                scope="encoder")

    # get output projection function
    output_fn, selector_fn, sequence_loss, sampled_sequence_loss, total_loss = \
        output_projection_layer(num_units, num_symbols, num_samples)

    # --- training decoder ---
    with tf.variable_scope('decoder'):
        # get attention function (triple embeddings as external memory)
        attention_keys_init, attention_values_init, attention_score_fn_init, attention_construct_fn_init \
            = prepare_attention(encoder_output, 'bahdanau', num_units,
                                imem=triples_embedding,
                                output_alignments=output_alignments)  #'luong', num_units)
        decoder_fn_train = attention_decoder_fn_train(
            encoder_state, attention_keys_init, attention_values_init,
            attention_score_fn_init, attention_construct_fn_init,
            output_alignments=output_alignments,
            max_length=tf.reduce_max(self.responses_length))
        self.decoder_output, _, alignments_ta = dynamic_rnn_decoder(
            decoder_cell, decoder_fn_train, self.decoder_input,
            self.responses_length, scope="decoder_rnn")
        if output_alignments:
            # TensorArray (time-major) -> batch-major alignments.
            self.alignments = tf.transpose(alignments_ta.stack(), perm=[1, 0, 2])
            #self.alignments = tf.Print(self.alignments, [self.alignments], summarize=1e8)
            self.decoder_loss, self.ppx_loss, self.sentence_ppx = total_loss(
                self.decoder_output, self.responses_target, self.decoder_mask,
                self.alignments, triples_embedding, use_triples, one_hot_triples)
            self.sentence_ppx = tf.identity(self.sentence_ppx, 'ppx_loss')
            #self.decoder_loss = tf.Print(self.decoder_loss, ['decoder_loss', self.decoder_loss], summarize=1e6)
        else:
            self.decoder_loss, self.sentence_ppx = sequence_loss(
                self.decoder_output, self.responses_target, self.decoder_mask)
            self.sentence_ppx = tf.identity(self.sentence_ppx, 'ppx_loss')

    # --- inference decoder (reuses the training decoder's variables) ---
    with tf.variable_scope('decoder', reuse=True):
        # get attention function
        attention_keys, attention_values, attention_score_fn, attention_construct_fn \
            = prepare_attention(encoder_output, 'bahdanau', num_units, reuse=True,
                                imem=triples_embedding,
                                output_alignments=output_alignments)  #'luong', num_units)
        decoder_fn_inference = attention_decoder_fn_inference(
            output_fn, encoder_state, attention_keys, attention_values,
            attention_score_fn, attention_construct_fn, self.embed,
            GO_ID, EOS_ID, max_length, num_symbols,
            imem=entities_word_embedding, selector_fn=selector_fn)
        self.decoder_distribution, _, output_ids_ta = dynamic_rnn_decoder(
            decoder_cell, decoder_fn_inference, scope="decoder_rnn")
        if output_alignments:
            output_len = tf.shape(self.decoder_distribution)[1]
            output_ids = tf.transpose(output_ids_ta.gather(tf.range(output_len)))
            # Positive ids are vocab words; negative ids index copied
            # entities (hence the clip on -output_ids below).
            word_ids = tf.cast(tf.clip_by_value(output_ids, 0, num_symbols), tf.int64)
            # Offset entity indices per batch row so a flat gather works.
            entity_ids = tf.reshape(
                tf.clip_by_value(-output_ids, 0, num_symbols) +
                tf.reshape(tf.range(encoder_batch_size) * tf.shape(entities_word_embedding)[1],
                           [-1, 1]),
                [-1])
            entities = tf.reshape(
                tf.gather(tf.reshape(self.entities, [-1]), entity_ids), [-1, output_len])
            words = self.index2symbol.lookup(word_ids)
            # Per position: emit the vocab word, or the copied entity string.
            self.generation = tf.where(output_ids > 0, words, entities, name='generation')
        else:
            self.generation_index = tf.argmax(self.decoder_distribution, 2)
            self.generation = self.index2symbol.lookup(self.generation_index, name='generation')

    # initialize the training process
    self.learning_rate = tf.Variable(float(learning_rate), trainable=False, dtype=tf.float32)
    self.learning_rate_decay_op = self.learning_rate.assign(
        self.learning_rate * learning_rate_decay_factor)
    self.global_step = tf.Variable(0, trainable=False)

    # Gradients are taken w.r.t. ALL global variables here (not just
    # trainable ones); tf.gradients returns None for unconnected ones.
    self.params = tf.global_variables()

    # calculate the gradient of parameters
    #opt = tf.train.GradientDescentOptimizer(self.learning_rate)
    opt = tf.train.AdamOptimizer(learning_rate=learning_rate)
    # NOTE(review): reaches into AdamOptimizer's private attribute; also
    # Adam is constructed with the constant `learning_rate`, so the decay
    # op above does not affect it -- confirm this is intended.
    self.lr = opt._lr

    gradients = tf.gradients(self.decoder_loss, self.params)
    clipped_gradients, self.gradient_norm = tf.clip_by_global_norm(gradients, max_gradient_norm)
    self.update = opt.apply_gradients(zip(clipped_gradients, self.params),
                                      global_step=self.global_step)

    tf.summary.scalar('decoder_loss', self.decoder_loss)
    for each in tf.trainable_variables():
        tf.summary.histogram(each.name, each)
    self.merged_summary_op = tf.summary.merge_all()

    # Rolling checkpoints plus a separate long-lived per-epoch saver.
    self.saver = tf.train.Saver(write_version=tf.train.SaverDef.V2,
                                max_to_keep=3, pad_step_number=True,
                                keep_checkpoint_every_n_hours=1.0)
    self.saver_epoch = tf.train.Saver(write_version=tf.train.SaverDef.V2,
                                      max_to_keep=1000, pad_step_number=True)
def __init__(self, num_symbols, num_embed_units, num_units, num_layers, is_train,
             vocab=None, embed=None, learning_rate=0.1,
             learning_rate_decay_factor=0.95, max_gradient_norm=5.0,
             num_samples=512, max_length=30, use_lstm=True):
    """Build a four-turn hierarchical encoder-decoder with graph attention.

    Turn 1 (posts_1) is encoded with a plain RNN; each later turn is decoded
    from the previous turn's final state with Luong attention over a static
    "graph attention" summary of that turn's knowledge triples
    (entity_t = (head, relation, tail) word embeddings, masked by
    entity_mask_t).  The final response is decoded from turn 4.  In training
    mode the four decoder losses are summed and minimised with Momentum SGD;
    otherwise a greedy inference decoder produces `self.generation`.

    Fixes vs. original:
    * `MultiRNNCell([cell] * num_layers)` replicated ONE cell object across
      all layers (ties layer weights / raises variable-reuse errors on
      TF >= 1.2); each layer now gets its own cell instance, matching the
      sibling model in this file.
    * Removed two duplicated statements (`responses_target` lookup and the
      `batch_size, decoder_len` computation) that recomputed identical
      values.
    """
    # ---- placeholders -------------------------------------------------
    # Four context turns as string tokens, per-turn knowledge triples
    # (batch * n_graph * n_triple * 3 strings) and their float masks.
    self.posts_1 = tf.placeholder(tf.string, shape=(None, None))
    self.posts_2 = tf.placeholder(tf.string, shape=(None, None))
    self.posts_3 = tf.placeholder(tf.string, shape=(None, None))
    self.posts_4 = tf.placeholder(tf.string, shape=(None, None))
    self.entity_1 = tf.placeholder(tf.string, shape=(None, None, None, 3))
    self.entity_2 = tf.placeholder(tf.string, shape=(None, None, None, 3))
    self.entity_3 = tf.placeholder(tf.string, shape=(None, None, None, 3))
    self.entity_4 = tf.placeholder(tf.string, shape=(None, None, None, 3))
    self.entity_mask_1 = tf.placeholder(tf.float32, shape=(None, None, None))
    self.entity_mask_2 = tf.placeholder(tf.float32, shape=(None, None, None))
    self.entity_mask_3 = tf.placeholder(tf.float32, shape=(None, None, None))
    self.entity_mask_4 = tf.placeholder(tf.float32, shape=(None, None, None))
    self.posts_length_1 = tf.placeholder(tf.int32, shape=(None))
    self.posts_length_2 = tf.placeholder(tf.int32, shape=(None))
    self.posts_length_3 = tf.placeholder(tf.int32, shape=(None))
    self.posts_length_4 = tf.placeholder(tf.int32, shape=(None))
    self.responses = tf.placeholder(tf.string, shape=(None, None))
    self.responses_length = tf.placeholder(tf.int32, shape=(None))

    self.epoch = tf.Variable(0, trainable=False, name='epoch')
    self.epoch_add_op = self.epoch.assign(self.epoch + 1)

    if is_train:
        self.symbols = tf.Variable(vocab, trainable=False, name="symbols")
    else:
        # Dummy content; presumably overwritten when the checkpoint is
        # restored -- NOTE(review): confirm against the restore path.
        self.symbols = tf.Variable(np.array(['.'] * num_symbols), name="symbols")

    # Static string -> index table over the vocabulary (UNK for misses).
    self.symbol2index = HashTable(KeyValueTensorInitializer(
        self.symbols,
        tf.Variable(np.array([i for i in range(num_symbols)], dtype=np.int32), False)),
        default_value=UNK_ID, name="symbol2index")

    # ---- token ids and teacher-forcing inputs -------------------------
    self.posts_input_1 = self.symbol2index.lookup(self.posts_1)
    self.posts_2_target = self.posts_2_embed = self.symbol2index.lookup(self.posts_2)
    self.posts_3_target = self.posts_3_embed = self.symbol2index.lookup(self.posts_3)
    self.posts_4_target = self.posts_4_embed = self.symbol2index.lookup(self.posts_4)
    self.responses_target = self.symbol2index.lookup(self.responses)

    batch_size, decoder_len = tf.shape(self.posts_1)[0], tf.shape(self.responses)[1]

    def _shift_right(target, length):
        """Drop the last step and prepend GO_ID (teacher-forcing input)."""
        return tf.concat([
            tf.ones([batch_size, 1], dtype=tf.int32) * GO_ID,
            tf.split(target, [length - 1, 1], 1)[0]
        ], 1)

    self.posts_input_2 = _shift_right(self.posts_2_embed, tf.shape(self.posts_2)[1])
    self.posts_input_3 = _shift_right(self.posts_3_embed, tf.shape(self.posts_3)[1])
    self.posts_input_4 = _shift_right(self.posts_4_embed, tf.shape(self.posts_4)[1])
    self.responses_input = _shift_right(self.responses_target, decoder_len)

    def _length_mask(lengths, max_len):
        """1.0 inside the true sequence length, 0.0 on padding."""
        return tf.reshape(
            tf.cumsum(tf.one_hot(lengths - 1, max_len), reverse=True, axis=1),
            [-1, max_len])

    self.encoder_2_mask = _length_mask(self.posts_length_2, tf.shape(self.posts_2)[1])
    self.encoder_3_mask = _length_mask(self.posts_length_3, tf.shape(self.posts_3)[1])
    self.encoder_4_mask = _length_mask(self.posts_length_4, tf.shape(self.posts_4)[1])
    self.decoder_mask = _length_mask(self.responses_length, decoder_len)

    # ---- embeddings ---------------------------------------------------
    if embed is None:
        # initialize the embedding randomly
        self.embed = tf.get_variable('embed', [num_symbols, num_embed_units], tf.float32)
    else:
        # initialize the embedding from pre-trained word vectors
        self.embed = tf.get_variable('embed', dtype=tf.float32, initializer=embed)

    self.encoder_input_1 = tf.nn.embedding_lookup(self.embed, self.posts_input_1)
    self.encoder_input_2 = tf.nn.embedding_lookup(self.embed, self.posts_input_2)
    self.encoder_input_3 = tf.nn.embedding_lookup(self.embed, self.posts_input_3)
    self.encoder_input_4 = tf.nn.embedding_lookup(self.embed, self.posts_input_4)
    self.decoder_input = tf.nn.embedding_lookup(self.embed, self.responses_input)

    def _triple_embedding(entity):
        """Embed a (batch, n_graph, n_triple, 3) string tensor, flattening
        head/relation/tail word vectors into the last axis."""
        return tf.reshape(
            tf.nn.embedding_lookup(self.embed, self.symbol2index.lookup(entity)),
            [batch_size, tf.shape(entity)[1], tf.shape(entity)[2], 3 * num_embed_units])

    entity_embedding_1 = _triple_embedding(self.entity_1)
    entity_embedding_2 = _triple_embedding(self.entity_2)
    entity_embedding_3 = _triple_embedding(self.entity_3)
    entity_embedding_4 = _triple_embedding(self.entity_4)

    head_1, relation_1, tail_1 = tf.split(entity_embedding_1, [num_embed_units] * 3, axis=3)
    head_2, relation_2, tail_2 = tf.split(entity_embedding_2, [num_embed_units] * 3, axis=3)
    head_3, relation_3, tail_3 = tf.split(entity_embedding_3, [num_embed_units] * 3, axis=3)
    head_4, relation_4, tail_4 = tf.split(entity_embedding_4, [num_embed_units] * 3, axis=3)

    def _graph_attention(head, relation, tail, mask):
        """Static graph attention: softmax(relation . tanh(W[h;t])) pooled
        over the triple axis, with `mask` zeroing padded triples."""
        # [batch, n_graph, n_triple, 2*embed_units]
        head_tail = tf.concat([head, tail], axis=3)
        # [batch, n_graph, n_triple, embed_units]
        head_tail_transformed = tf.layers.dense(head_tail, num_embed_units,
                                                activation=tf.tanh,
                                                name='head_tail_transform')
        relation_transformed = tf.layers.dense(relation, num_embed_units,
                                               name='relation_transform')
        # [batch, n_graph, n_triple]
        e_weight = tf.reduce_sum(relation_transformed * head_tail_transformed, axis=3)
        alpha_weight = tf.nn.softmax(e_weight)
        # [batch, n_graph, 2*embed_units]
        return tf.reduce_sum(
            tf.expand_dims(alpha_weight, 3) * (tf.expand_dims(mask, 3) * head_tail),
            axis=2)

    # One set of attention weights shared across all four turns.
    with tf.variable_scope('graph_attention'):
        graph_embed_1 = _graph_attention(head_1, relation_1, tail_1, self.entity_mask_1)
    with tf.variable_scope('graph_attention', reuse=True):
        graph_embed_2 = _graph_attention(head_2, relation_2, tail_2, self.entity_mask_2)
    with tf.variable_scope('graph_attention', reuse=True):
        graph_embed_3 = _graph_attention(head_3, relation_3, tail_3, self.entity_mask_3)
    with tf.variable_scope('graph_attention', reuse=True):
        graph_embed_4 = _graph_attention(head_4, relation_4, tail_4, self.entity_mask_4)

    # ---- RNN cells: one fresh cell per layer (see docstring) ----------
    if use_lstm:
        cell = MultiRNNCell([LSTMCell(num_units) for _ in range(num_layers)])
    else:
        cell = MultiRNNCell([GRUCell(num_units) for _ in range(num_layers)])

    output_fn, sampled_sequence_loss = output_projection_layer(
        num_units, num_symbols, num_samples)

    # ---- chained turn decoders ----------------------------------------
    encoder_output_1, encoder_state_1 = dynamic_rnn(cell, self.encoder_input_1,
                                                    self.posts_length_1,
                                                    dtype=tf.float32, scope="encoder")

    attention_keys_1, attention_values_1, attention_score_fn_1, attention_construct_fn_1 \
        = attention_decoder_fn.prepare_attention(graph_embed_1, encoder_output_1,
                                                 'luong', num_units)
    decoder_fn_train_1 = attention_decoder_fn.attention_decoder_fn_train(
        encoder_state_1, attention_keys_1, attention_values_1,
        attention_score_fn_1, attention_construct_fn_1,
        max_length=tf.reduce_max(self.posts_length_2))
    encoder_output_2, encoder_state_2, alignments_ta_2 = dynamic_rnn_decoder(
        cell, decoder_fn_train_1, self.encoder_input_2, self.posts_length_2,
        scope="decoder")
    self.alignments_2 = tf.transpose(alignments_ta_2.stack(), perm=[1, 0, 2])
    self.decoder_loss_2 = sampled_sequence_loss(encoder_output_2,
                                                self.posts_2_target,
                                                self.encoder_2_mask)

    # Later turns reuse the attention/decoder variables created above.
    with variable_scope.variable_scope('', reuse=True):
        attention_keys_2, attention_values_2, attention_score_fn_2, attention_construct_fn_2 \
            = attention_decoder_fn.prepare_attention(graph_embed_2, encoder_output_2,
                                                     'luong', num_units)
        decoder_fn_train_2 = attention_decoder_fn.attention_decoder_fn_train(
            encoder_state_2, attention_keys_2, attention_values_2,
            attention_score_fn_2, attention_construct_fn_2,
            max_length=tf.reduce_max(self.posts_length_3))
        encoder_output_3, encoder_state_3, alignments_ta_3 = dynamic_rnn_decoder(
            cell, decoder_fn_train_2, self.encoder_input_3, self.posts_length_3,
            scope="decoder")
        self.alignments_3 = tf.transpose(alignments_ta_3.stack(), perm=[1, 0, 2])
        self.decoder_loss_3 = sampled_sequence_loss(
            encoder_output_3, self.posts_3_target, self.encoder_3_mask)

        attention_keys_3, attention_values_3, attention_score_fn_3, attention_construct_fn_3 \
            = attention_decoder_fn.prepare_attention(graph_embed_3, encoder_output_3,
                                                     'luong', num_units)
        decoder_fn_train_3 = attention_decoder_fn.attention_decoder_fn_train(
            encoder_state_3, attention_keys_3, attention_values_3,
            attention_score_fn_3, attention_construct_fn_3,
            max_length=tf.reduce_max(self.posts_length_4))
        encoder_output_4, encoder_state_4, alignments_ta_4 = dynamic_rnn_decoder(
            cell, decoder_fn_train_3, self.encoder_input_4, self.posts_length_4,
            scope="decoder")
        self.alignments_4 = tf.transpose(alignments_ta_4.stack(), perm=[1, 0, 2])
        self.decoder_loss_4 = sampled_sequence_loss(
            encoder_output_4, self.posts_4_target, self.encoder_4_mask)

        # Attention for the final response decoder (turn-4 context).
        attention_keys, attention_values, attention_score_fn, attention_construct_fn \
            = attention_decoder_fn.prepare_attention(graph_embed_4, encoder_output_4,
                                                     'luong', num_units)

    # ---- final response decoder ---------------------------------------
    if is_train:
        with variable_scope.variable_scope('', reuse=True):
            decoder_fn_train = attention_decoder_fn.attention_decoder_fn_train(
                encoder_state_4, attention_keys, attention_values,
                attention_score_fn, attention_construct_fn,
                max_length=tf.reduce_max(self.responses_length))
            self.decoder_output, _, alignments_ta = dynamic_rnn_decoder(
                cell, decoder_fn_train, self.decoder_input,
                self.responses_length, scope="decoder")
            self.alignments = tf.transpose(alignments_ta.stack(), perm=[1, 0, 2])
            self.decoder_loss = sampled_sequence_loss(
                self.decoder_output, self.responses_target, self.decoder_mask)

        self.params = tf.trainable_variables()
        self.learning_rate = tf.Variable(float(learning_rate),
                                         trainable=False, dtype=tf.float32)
        self.learning_rate_decay_op = self.learning_rate.assign(
            self.learning_rate * learning_rate_decay_factor)
        self.global_step = tf.Variable(0, trainable=False)

        #opt = tf.train.GradientDescentOptimizer(self.learning_rate)
        opt = tf.train.MomentumOptimizer(self.learning_rate, 0.9)
        # All four turn losses are trained jointly.
        gradients = tf.gradients(
            self.decoder_loss + self.decoder_loss_2 + self.decoder_loss_3 +
            self.decoder_loss_4, self.params)
        clipped_gradients, self.gradient_norm = tf.clip_by_global_norm(
            gradients, max_gradient_norm)
        self.update = opt.apply_gradients(zip(clipped_gradients, self.params),
                                          global_step=self.global_step)
    else:
        with variable_scope.variable_scope('', reuse=True):
            decoder_fn_inference = attention_decoder_fn.attention_decoder_fn_inference(
                output_fn, encoder_state_4, attention_keys, attention_values,
                attention_score_fn, attention_construct_fn, self.embed,
                GO_ID, EOS_ID, max_length, num_symbols)
            self.decoder_distribution, _, alignments_ta = dynamic_rnn_decoder(
                cell, decoder_fn_inference, scope="decoder")
            output_len = tf.shape(self.decoder_distribution)[1]
            self.alignments = tf.transpose(
                alignments_ta.gather(tf.range(output_len)), [1, 0, 2])
            # Skip the first two ids before argmax, then add the offset back.
            self.generation_index = tf.argmax(
                tf.split(self.decoder_distribution, [2, num_symbols - 2], 2)[1],
                2) + 2  # for removing UNK
            self.generation = tf.nn.embedding_lookup(self.symbols,
                                                     self.generation_index,
                                                     name="generation")
        self.params = tf.trainable_variables()

    self.saver = tf.train.Saver(tf.global_variables(),
                                write_version=tf.train.SaverDef.V2,
                                max_to_keep=10, pad_step_number=True,
                                keep_checkpoint_every_n_hours=1.0)
def __init__(self, num_symbols, num_embed_units, num_units, num_layers, beam_size,
             embed, learning_rate=0.5, remove_unk=False,
             learning_rate_decay_factor=0.95, max_gradient_norm=5.0,
             num_samples=512, max_length=8, use_lstm=False):
    """Attention seq2seq with greedy and beam-search inference decoders.

    Builds string placeholders, mutable vocab hash tables, an RNN encoder,
    a Luong-attention training decoder (sampled-softmax loss), a greedy
    inference decoder, a beam-search inference decoder, a SGD training op
    with gradient clipping, and a serving exporter exposing the beam-search
    outputs.

    Fix vs. original: `MultiRNNCell([LSTMCell(num_units)] * num_layers)`
    (and the GRU variant) replicated ONE cell object across all layers,
    which ties layer weights / raises variable-reuse errors on TF >= 1.2;
    each layer now gets its own cell instance, consistent with the other
    models in this file.
    """
    self.posts = tf.placeholder(tf.string, (None, None), 'enc_inps')  # batch*len
    self.posts_length = tf.placeholder(tf.int32, (None), 'enc_lens')  # batch
    self.responses = tf.placeholder(tf.string, (None, None), 'dec_inps')  # batch*len
    self.responses_length = tf.placeholder(tf.int32, (None), 'dec_lens')  # batch

    # initialize the training process
    self.learning_rate = tf.Variable(float(learning_rate), trainable=False, dtype=tf.float32)
    self.learning_rate_decay_op = self.learning_rate.assign(
        self.learning_rate * learning_rate_decay_factor)
    self.global_step = tf.Variable(0, trainable=False)

    # word <-> id lookup tables (saved in checkpoints).
    self.symbol2index = MutableHashTable(key_dtype=tf.string,
                                         value_dtype=tf.int64,
                                         default_value=UNK_ID,
                                         shared_name="in_table",
                                         name="in_table",
                                         checkpoint=True)
    self.index2symbol = MutableHashTable(key_dtype=tf.int64,
                                         value_dtype=tf.string,
                                         default_value='_UNK',
                                         shared_name="out_table",
                                         name="out_table",
                                         checkpoint=True)

    # build the vocab table (string to index)
    self.posts_input = self.symbol2index.lookup(self.posts)  # batch*len
    self.responses_target = self.symbol2index.lookup(self.responses)  # batch*len

    batch_size, decoder_len = tf.shape(self.responses)[0], tf.shape(self.responses)[1]
    # GO + targets without the last step = teacher-forcing decoder input.
    self.responses_input = tf.concat([
        tf.ones([batch_size, 1], dtype=tf.int64) * GO_ID,
        tf.split(self.responses_target, [decoder_len - 1, 1], 1)[0]
    ], 1)  # batch*len
    # 1.0 inside the true response length, 0.0 on padding.
    self.decoder_mask = tf.reshape(
        tf.cumsum(tf.one_hot(self.responses_length - 1, decoder_len),
                  reverse=True, axis=1), [-1, decoder_len])

    # build the embedding table (index to vector)
    if embed is None:
        # initialize the embedding randomly
        self.embed = tf.get_variable('embed', [num_symbols, num_embed_units], tf.float32)
    else:
        # initialize the embedding by pre-trained word vectors
        self.embed = tf.get_variable('embed', dtype=tf.float32, initializer=embed)

    self.encoder_input = tf.nn.embedding_lookup(self.embed, self.posts_input)  # batch*len*unit
    self.decoder_input = tf.nn.embedding_lookup(self.embed, self.responses_input)

    # One distinct cell instance per layer (see docstring).
    if use_lstm:
        cell = MultiRNNCell([LSTMCell(num_units) for _ in range(num_layers)])
    else:
        cell = MultiRNNCell([GRUCell(num_units) for _ in range(num_layers)])

    # rnn encoder
    encoder_output, encoder_state = dynamic_rnn(cell, self.encoder_input,
                                                self.posts_length,
                                                dtype=tf.float32, scope="encoder")

    # get output projection function
    output_fn, sampled_sequence_loss = output_projection_layer(
        num_units, num_symbols, num_samples)

    # get attention function
    attention_keys, attention_values, attention_score_fn, attention_construct_fn \
        = attention_decoder_fn.prepare_attention(encoder_output, 'luong', num_units)

    # --- training decoder ---
    with tf.variable_scope('decoder'):
        decoder_fn_train = attention_decoder_fn.attention_decoder_fn_train(
            encoder_state, attention_keys, attention_values,
            attention_score_fn, attention_construct_fn)
        self.decoder_output, _, _ = dynamic_rnn_decoder(
            cell, decoder_fn_train, self.decoder_input, self.responses_length,
            scope="decoder_rnn")
        self.decoder_loss = sampled_sequence_loss(self.decoder_output,
                                                  self.responses_target,
                                                  self.decoder_mask)

    # --- greedy inference decoder (reuses training variables) ---
    with tf.variable_scope('decoder', reuse=True):
        decoder_fn_inference = attention_decoder_fn.attention_decoder_fn_inference(
            output_fn, encoder_state, attention_keys, attention_values,
            attention_score_fn, attention_construct_fn, self.embed,
            GO_ID, EOS_ID, max_length, num_symbols)
        self.decoder_distribution, _, _ = dynamic_rnn_decoder(
            cell, decoder_fn_inference, scope="decoder_rnn")
        # Skip the first two ids before argmax, then add the offset back.
        self.generation_index = tf.argmax(
            tf.split(self.decoder_distribution, [2, num_symbols - 2], 2)[1],
            2) + 2  # for removing UNK
        self.generation = self.index2symbol.lookup(self.generation_index,
                                                   name='generation')

    # --- beam-search inference decoder ---
    with tf.variable_scope('decoder', reuse=True):
        decoder_fn_beam_inference = attention_decoder_fn_beam_inference(
            output_fn, encoder_state, attention_keys, attention_values,
            attention_score_fn, attention_construct_fn, self.embed,
            GO_ID, EOS_ID, max_length, num_symbols, beam_size, remove_unk)
        _, _, self.context_state = dynamic_rnn_decoder(
            cell, decoder_fn_beam_inference, scope="decoder_rnn")
        (log_beam_probs, beam_parents, beam_symbols, result_probs,
         result_parents, result_symbols) = self.context_state

        # Reshape the time-major TensorArrays to [batch, time, beam]
        # (result_* tracks 2*beam_size candidates per step).
        self.beam_parents = tf.transpose(tf.reshape(
            beam_parents.stack(), [max_length + 1, -1, beam_size]),
            [1, 0, 2], name='beam_parents')
        self.beam_symbols = tf.transpose(
            tf.reshape(beam_symbols.stack(), [max_length + 1, -1, beam_size]),
            [1, 0, 2])
        self.beam_symbols = self.index2symbol.lookup(
            tf.cast(self.beam_symbols, tf.int64), name="beam_symbols")

        self.result_probs = tf.transpose(tf.reshape(
            result_probs.stack(), [max_length + 1, -1, beam_size * 2]),
            [1, 0, 2], name='result_probs')
        self.result_symbols = tf.transpose(
            tf.reshape(result_symbols.stack(), [max_length + 1, -1, beam_size * 2]),
            [1, 0, 2])
        self.result_parents = tf.transpose(tf.reshape(
            result_parents.stack(), [max_length + 1, -1, beam_size * 2]),
            [1, 0, 2], name='result_parents')
        self.result_symbols = self.index2symbol.lookup(
            tf.cast(self.result_symbols, tf.int64), name='result_symbols')

    self.params = tf.trainable_variables()

    # calculate the gradient of parameters
    opt = tf.train.GradientDescentOptimizer(self.learning_rate)
    gradients = tf.gradients(self.decoder_loss, self.params)
    clipped_gradients, self.gradient_norm = tf.clip_by_global_norm(
        gradients, max_gradient_norm)
    self.update = opt.apply_gradients(zip(clipped_gradients, self.params),
                                      global_step=self.global_step)

    self.saver = tf.train.Saver(write_version=tf.train.SaverDef.V2,
                                max_to_keep=3, pad_step_number=True,
                                keep_checkpoint_every_n_hours=1.0)

    # Exporter for serving: beam-search outputs keyed for the client.
    self.model_exporter = exporter.Exporter(self.saver)
    inputs = {"enc_inps:0": self.posts, "enc_lens:0": self.posts_length}
    outputs = {
        "beam_symbols": self.beam_symbols,
        "beam_parents": self.beam_parents,
        "result_probs": self.result_probs,
        "result_symbols": self.result_symbols,
        "result_parents": self.result_parents
    }
    self.model_exporter.init(tf.get_default_graph().as_graph_def(),
                             named_graph_signatures={
                                 "inputs": exporter.generic_signature(inputs),
                                 "outputs": exporter.generic_signature(outputs)
                             })
def __init__(
        self,
        num_symbols,          # vocabulary size
        num_embed_units,      # word-embedding dimension
        num_units,            # units per RNN layer
        num_layers,           # number of RNN layers
        embed,                # pre-trained word embedding (None -> random init)
        entity_embed=None,    # pre-trained entity/relation embedding
        num_entities=0,       # number of entities (+ relations)
        num_trans_units=100,  # entity-embedding dimension
        learning_rate=0.0001,
        learning_rate_decay_factor=0.95,
        max_gradient_norm=5.0,
        num_samples=500,      # number of samples for sampled softmax
        max_length=60,
        mem_use=True,
        output_alignments=True,
        use_lstm=False):
    """Build the graph of a knowledge-graph-aware seq2seq dialogue model.

    Wires string placeholders through word/entity hash tables, builds a
    static graph-attention summary of knowledge triples, runs a GRU
    encoder plus training/inference attention decoders, and sets up the
    Adam training op with gradient clipping, summaries, and savers.
    """
    # ----- placeholders ---------------------------------------------------
    self.posts = tf.placeholder(tf.string, (None, None),
                                'enc_inps')  # batch_size * encoder_len
    self.posts_length = tf.placeholder(tf.int32, (None),
                                       'enc_lens')  # batch_size
    self.responses = tf.placeholder(tf.string, (None, None),
                                    'dec_inps')  # batch_size * decoder_len
    self.responses_length = tf.placeholder(tf.int32, (None),
                                           'dec_lens')  # batch_size
    self.entities = tf.placeholder(
        tf.string, (None, None, None),
        'entities')  # batch_size * triple_num * triple_len
    self.entity_masks = tf.placeholder(tf.string, (None, None),
                                       'entity_masks')  # unused
    self.triples = tf.placeholder(
        tf.string, (None, None, None, 3),
        'triples')  # batch_size * triple_num * triple_len * 3
    self.posts_triple = tf.placeholder(
        tf.int32, (None, None, 1), 'enc_triples')  # batch_size * encoder_len
    self.responses_triple = tf.placeholder(
        tf.string, (None, None, 3),
        'dec_triples')  # batch_size * decoder_len * 3
    self.match_triples = tf.placeholder(
        tf.int32, (None, None, None),
        'match_triples')  # batch_size * decoder_len * triple_num

    # encoder batch size and encoder sequence length
    encoder_batch_size, encoder_len = tf.unstack(tf.shape(self.posts))
    # triple_num: number of knowledge graphs attached to each post (padded)
    triple_num = tf.shape(self.triples)[1]
    # triple_len: number of triples inside each knowledge graph (padded)
    triple_len = tf.shape(self.triples)[2]

    # one-hot encoding of the knowledge triples actually used
    one_hot_triples = tf.one_hot(
        self.match_triples,
        triple_len)  # batch_size * decoder_len * triple_num * triple_len
    # 1 marks the decoding steps whose token came from a knowledge triple
    use_triples = tf.reduce_sum(one_hot_triples,
                                axis=[2, 3])  # batch_size * decoder_len

    # word -> index hash table
    self.symbol2index = MutableHashTable(
        key_dtype=tf.string,    # key tensor dtype
        value_dtype=tf.int64,   # value tensor dtype
        default_value=UNK_ID,   # default value for missing keys
        shared_name=
        "in_table",  # if non-empty, shared under this name across sessions
        name="in_table",  # op name
        checkpoint=True
    )  # if True, table contents are saved to and restored from checkpoints;
    # with an empty shared_name the table node name is used for sharing.
    # index -> word hash table
    self.index2symbol = MutableHashTable(key_dtype=tf.int64,
                                         value_dtype=tf.string,
                                         default_value='_UNK',
                                         shared_name="out_table",
                                         name="out_table",
                                         checkpoint=True)
    # entity -> index hash table
    self.entity2index = MutableHashTable(key_dtype=tf.string,
                                         value_dtype=tf.int64,
                                         default_value=NONE_ID,
                                         shared_name="entity_in_table",
                                         name="entity_in_table",
                                         checkpoint=True)
    # index -> entity hash table
    self.index2entity = MutableHashTable(key_dtype=tf.int64,
                                         value_dtype=tf.string,
                                         default_value='_NONE',
                                         shared_name="entity_out_table",
                                         name="entity_out_table",
                                         checkpoint=True)

    # map post strings to word ids
    self.posts_word_id = self.symbol2index.lookup(
        self.posts)  # batch_size * encoder_len
    # map post strings to entity ids
    self.posts_entity_id = self.entity2index.lookup(
        self.posts)  # batch_size * encoder_len
    # map response strings to word ids
    self.responses_target = self.symbol2index.lookup(
        self.responses)  # batch_size * decoder_len

    # decoder batch size and decoder sequence length
    batch_size, decoder_len = tf.shape(self.responses)[0], tf.shape(
        self.responses)[1]
    # drop the last column of responses_target and prepend GO_ID
    self.responses_word_id = tf.concat([
        tf.ones([batch_size, 1], dtype=tf.int64) * GO_ID,
        tf.split(self.responses_target, [decoder_len - 1, 1], 1)[0]
    ], 1)  # batch_size * decoder_len

    # response mask: one-hot the (length-1) position, then a reversed
    # cumulative sum gives 1 inside the valid range and 0 after it;
    # the final reshape is arguably redundant
    self.decoder_mask = tf.reshape(
        tf.cumsum(tf.one_hot(self.responses_length - 1, decoder_len),
                  reverse=True,
                  axis=1), [-1, decoder_len])  # batch_size * decoder_len

    # word/entity embeddings: use provided matrices, else random init
    if embed is None:
        self.embed = tf.get_variable('word_embed',
                                     [num_symbols, num_embed_units],
                                     tf.float32)
    else:
        self.embed = tf.get_variable('word_embed',
                                     dtype=tf.float32,
                                     initializer=embed)
    if entity_embed is None:
        self.entity_trans = tf.get_variable(
            'entity_embed', [num_entities, num_trans_units],
            tf.float32,
            trainable=False)
    else:
        self.entity_trans = tf.get_variable('entity_embed',
                                            dtype=tf.float32,
                                            initializer=entity_embed,
                                            trainable=False)

    # dense layer over the (frozen) entity embeddings, tanh activation.
    # NOTE(review): original author questioned why this extra dense layer
    # is needed — presumably to adapt the frozen TransE vectors; confirm.
    self.entity_trans_transformed = tf.layers.dense(
        self.entity_trans,
        num_trans_units,
        activation=tf.tanh,
        name='trans_transformation')
    # 7 x num_trans_units zero-initialized rows
    padding_entity = tf.get_variable('entity_padding_embed',
                                     [7, num_trans_units],
                                     dtype=tf.float32,
                                     initializer=tf.zeros_initializer())
    # prepend the padding rows to the entity table — presumably reserving
    # the first 7 ids for special/padding entities; confirm against vocab.
    self.entity_embed = tf.concat(
        [padding_entity, self.entity_trans_transformed], axis=0)

    # embedding_lookup adds a trailing dimension; the reshape folds the
    # three looked-up vectors per triple into one 3*num_trans_units vector
    triples_embedding = tf.reshape(
        tf.nn.embedding_lookup(self.entity_embed,
                               self.entity2index.lookup(self.triples)),
        [encoder_batch_size, triple_num, -1, 3 * num_trans_units])
    entities_word_embedding = tf.reshape(
        tf.nn.embedding_lookup(self.embed,
                               self.symbol2index.lookup(self.entities)),
        [encoder_batch_size, -1, num_embed_units
         ])  # [batch_size, triple_num*triple_len, num_embed_units]

    # split head, relation, tail
    head, relation, tail = tf.split(triples_embedding,
                                    [num_trans_units] * 3,
                                    axis=3)

    # static graph attention
    with tf.variable_scope('graph_attention'):
        # concatenate head and tail
        head_tail = tf.concat(
            [head, tail],
            axis=3)  # batch_size * triple_num * triple_len * 200
        # tanh(dot(W, head_tail))
        head_tail_transformed = tf.layers.dense(
            head_tail,
            num_trans_units,
            activation=tf.tanh,
            name='head_tail_transform'
        )  # batch_size * triple_num * triple_len * 100
        # dot(W, relation)
        relation_transformed = tf.layers.dense(
            relation, num_trans_units, name='relation_transform'
        )  # batch_size * triple_num * triple_len * 100
        # elementwise product then sum == inner product:
        # dot(transpose(dot(W, relation)), tanh(dot(W, head_tail)))
        e_weight = tf.reduce_sum(
            relation_transformed * head_tail_transformed,
            axis=3)  # batch_size * triple_num * triple_len
        # alpha weight of each triple within its graph
        alpha_weight = tf.nn.softmax(
            e_weight)  # batch_size * triple_num * triple_len
        # expand_dims makes alpha_weight
        # batch_size * triple_num * triple_len * 1; summing over axis 2
        # yields a 100-dim vector per graph
        graph_embed = tf.reduce_sum(
            tf.expand_dims(alpha_weight, 3) * head_tail,
            axis=2)  # batch_size * triple_num * 100

    # Build an index matrix and gather per-position graph vectors:
    #   tf.range(encoder_batch_size) reshaped to [batch, 1, 1] and tiled to
    #   [batch, encoder_len, 1], concatenated with posts_triple on the last
    #   axis, gives a [batch_idx, graph_idx] pair per encoder position, e.g.
    #     [[[0 0], [0 0], [0 1], [0 2], ...],
    #      [[1 0], [1 0], [1 1], [1 2], ...], ...]
    #   tf.gather_nd then pulls the matching graph vector for every encoder
    #   position: encoder_batch_size * encoder_len * 100
    graph_embed_input = tf.gather_nd(
        graph_embed,
        tf.concat([
            tf.tile(
                tf.reshape(tf.range(encoder_batch_size, dtype=tf.int32),
                           [-1, 1, 1]), [1, encoder_len, 1]),
            self.posts_triple
        ],
                  axis=2))

    # response triples as entity embeddings: batch_size * decoder_len * 300
    triple_embed_input = tf.reshape(
        tf.nn.embedding_lookup(
            self.entity_embed,
            self.entity2index.lookup(self.responses_triple)),
        [batch_size, decoder_len, 3 * num_trans_units])

    # posts_word_id as word embeddings
    post_word_input = tf.nn.embedding_lookup(
        self.embed, self.posts_word_id)  # batch_size * encoder_len * 300
    # responses_word_id as word embeddings
    response_word_input = tf.nn.embedding_lookup(
        self.embed, self.responses_word_id)  # batch_size * decoder_len * 300

    # encoder input: word embedding || graph embedding
    self.encoder_input = tf.concat(
        [post_word_input, graph_embed_input],
        axis=2)  # batch_size * encoder_len * 400
    # decoder input: word embedding || triple embedding
    self.decoder_input = tf.concat(
        [response_word_input, triple_embed_input],
        axis=2)  # batch_size * decoder_len * 600

    # deep RNN cells
    encoder_cell = MultiRNNCell(
        [GRUCell(num_units) for _ in range(num_layers)])
    decoder_cell = MultiRNNCell(
        [GRUCell(num_units) for _ in range(num_layers)])

    # rnn encoder
    encoder_output, encoder_state = dynamic_rnn(encoder_cell,
                                                self.encoder_input,
                                                self.posts_length,
                                                dtype=tf.float32,
                                                scope="encoder")

    # the vocabulary is too large for the RNN output dimension; the
    # projection maps from num_units up to num_symbols. Returns the output
    # fn, selector fn, sequence loss, sampled sequence loss, and total loss
    output_fn, selector_fn, sequence_loss, sampled_sequence_loss, total_loss = output_projection_layer(
        num_units, num_symbols, num_samples)

    # training decoder
    with tf.variable_scope('decoder'):
        # prepare attention:
        #   attention_keys_init: attention keys
        #   attention_values_init: attention values
        #   attention_score_fn_init: computes the attention context
        #   attention_construct_fn_init: concatenates all contexts
        attention_keys_init, attention_values_init, attention_score_fn_init, attention_construct_fn_init \
            = prepare_attention(encoder_output, 'bahdanau', num_units, imem=(graph_embed, triples_embedding), output_alignments=output_alignments and mem_use)  # alternative: 'luong', num_units
        # per-timestep input handler for training
        decoder_fn_train = attention_decoder_fn_train(
            encoder_state,
            attention_keys_init,
            attention_values_init,
            attention_score_fn_init,
            attention_construct_fn_init,
            output_alignments=output_alignments and mem_use,
            max_length=tf.reduce_max(self.responses_length))
        # outputs, final state, and a TensorArray of alignments
        self.decoder_output, _, alignments_ta = dynamic_rnn_decoder(
            decoder_cell,
            decoder_fn_train,
            self.decoder_input,
            self.responses_length,
            scope="decoder_rnn")
        if output_alignments:
            # NOTE(review): `self.alignments` is read here but never
            # assigned in this method, while `alignments_ta` above is
            # captured and unused — presumably the TensorArray should be
            # stacked/transposed into `self.alignments` first; confirm
            # against the original implementation.
            self.decoder_loss, self.ppx_loss, self.sentence_ppx = total_loss(
                self.decoder_output, self.responses_target,
                self.decoder_mask, self.alignments, triples_embedding,
                use_triples, one_hot_triples)
            # give sentence_ppx a stable, named op
            self.sentence_ppx = tf.identity(self.sentence_ppx,
                                            name='ppx_loss')
        else:
            self.decoder_loss = sequence_loss(self.decoder_output,
                                              self.responses_target,
                                              self.decoder_mask)

    # inference decoder (reuses the training decoder's variables)
    with tf.variable_scope('decoder', reuse=True):
        # prepare attention (reused weights)
        attention_keys, attention_values, attention_score_fn, attention_construct_fn \
            = prepare_attention(encoder_output, 'bahdanau', num_units, reuse=True, imem=(graph_embed, triples_embedding), output_alignments=output_alignments and mem_use)  # alternative: 'luong', num_units
        # imem: tuple of entity-word embeddings
        # [batch_size, triple_num*triple_len, num_embed_units] and triple
        # embeddings [encoder_batch_size, triple_num*triple_len,
        # 3*num_trans_units]
        decoder_fn_inference = attention_decoder_fn_inference(
            output_fn,
            encoder_state,
            attention_keys,
            attention_values,
            attention_score_fn,
            attention_construct_fn,
            self.embed,
            GO_ID,
            EOS_ID,
            max_length,
            num_symbols,
            imem=(entities_word_embedding,
                  tf.reshape(triples_embedding,
                             [encoder_batch_size, -1, 3 * num_trans_units])),
            selector_fn=selector_fn)
        self.decoder_distribution, _, output_ids_ta = dynamic_rnn_decoder(
            decoder_cell, decoder_fn_inference, scope="decoder_rnn")

        output_len = tf.shape(self.decoder_distribution)[1]  # decoder_len
        output_ids = tf.transpose(
            output_ids_ta.gather(
                tf.range(output_len)))  # [batch_size, decoder_len]
        # clip ids into the word range (negative ids denote entities)
        word_ids = tf.cast(tf.clip_by_value(output_ids, 0, num_symbols),
                           tf.int64)  # [batch_size, decoder_len]
        # position of each chosen entity word inside `entities`:
        # 1. tf.shape(entities_word_embedding)[1] = triple_num*triple_len
        # 2. tf.range(encoder_batch_size): [batch_size]
        # 3. reshape(range * width, [-1, 1]): [batch_size, 1] per-row offset
        #    into the flattened entities
        # 4. clip_by_value(-output_ids, 0, num_symbols): relative position
        # 5. entity_ids: [batch_size * decoder_len] absolute position after
        #    adding the offset
        entity_ids = tf.reshape(
            tf.clip_by_value(-output_ids, 0, num_symbols) + tf.reshape(
                tf.range(encoder_batch_size) *
                tf.shape(entities_word_embedding)[1], [-1, 1]), [-1])
        # the entity strings actually used:
        # 1. entities: [batch_size, triple_num, triple_len]
        # 2. reshape(self.entities, [-1]): flattened entity strings
        # 3. tf.gather: [batch_size*decoder_len]
        # 4. entities: [batch_size, output_len]
        entities = tf.reshape(
            tf.gather(tf.reshape(self.entities, [-1]), entity_ids),
            [-1, output_len])
        words = self.index2symbol.lookup(word_ids)  # ids -> word strings
        # where output_ids > 0 take the word, otherwise the entity
        self.generation = tf.where(output_ids > 0, words, entities)
        self.generation = tf.identity(self.generation, name='generation')

    # training setup
    self.learning_rate = tf.Variable(float(learning_rate),
                                     trainable=False,
                                     dtype=tf.float32)
    self.learning_rate_decay_op = self.learning_rate.assign(
        self.learning_rate * learning_rate_decay_factor)
    # number of parameter updates performed
    self.global_step = tf.Variable(0, trainable=False)
    # parameters to optimize.
    # NOTE(review): global_variables() (not trainable_variables()) — this
    # also includes non-trainable variables; confirm this is intentional.
    self.params = tf.global_variables()
    # optimizer
    opt = tf.train.AdamOptimizer(learning_rate=learning_rate)
    self.lr = opt._lr  # NOTE(review): reads a private optimizer attribute
    # gradients of decoder_loss w.r.t. params
    gradients = tf.gradients(self.decoder_loss, self.params)
    # gradient clipping by global norm
    clipped_gradients, self.gradient_norm = tf.clip_by_global_norm(
        gradients, max_gradient_norm)
    self.update = opt.apply_gradients(zip(clipped_gradients, self.params),
                                      global_step=self.global_step)

    tf.summary.scalar('decoder_loss', self.decoder_loss)
    for each in tf.trainable_variables():
        tf.summary.histogram(each.name, each)
    self.merged_summary_op = tf.summary.merge_all()

    self.saver = tf.train.Saver(write_version=tf.train.SaverDef.V2,
                                max_to_keep=3,
                                pad_step_number=True,
                                keep_checkpoint_every_n_hours=1.0)
    self.saver_epoch = tf.train.Saver(write_version=tf.train.SaverDef.V2,
                                      max_to_keep=1000,
                                      pad_step_number=True)
def __init__(self, num_symbols, num_qwords,  # modify
             num_embed_units, num_units, num_layers, is_train, vocab=None,
             embed=None, question_data=True, learning_rate=0.5,
             learning_rate_decay_factor=0.95, max_gradient_norm=5.0,
             num_samples=512, max_length=30, use_lstm=False):
    """Build a keyword-conditioned attention seq2seq graph.

    Python 2 module (bare ``print`` statements). When ``is_train`` is True
    the training decoder, loss and SGD update op are built; otherwise the
    inference decoder and generation ops are built.
    """
    # ----- placeholders ---------------------------------------------------
    self.posts = tf.placeholder(tf.string, shape=(None, None))  # batch*len
    self.posts_length = tf.placeholder(tf.int32, shape=(None))  # batch
    self.responses = tf.placeholder(tf.string,
                                    shape=(None, None))  # batch*len
    self.responses_length = tf.placeholder(tf.int32, shape=(None))  # batch
    self.keyword_tensor = tf.placeholder(
        tf.float32, shape=(None, 3, None))  # (batch * len) * 3 * numsymbol
    self.word_type = tf.placeholder(tf.int32, shape=(None))  # (batch * len)

    # build the vocab table (string to index)
    if is_train:
        self.symbols = tf.Variable(vocab, trainable=False, name="symbols")
    else:
        # placeholder vocab — presumably the real symbols are restored
        # from a checkpoint at inference time; confirm.
        self.symbols = tf.Variable(np.array(['.'] * num_symbols),
                                   name="symbols")
    self.symbol2index = HashTable(
        KeyValueTensorInitializer(
            self.symbols,
            tf.Variable(
                np.array([i for i in range(num_symbols)], dtype=np.int32),
                False)),
        default_value=UNK_ID,
        name="symbol2index")

    self.posts_input = self.symbol2index.lookup(self.posts)  # batch*len
    self.responses_target = self.symbol2index.lookup(
        self.responses)  # batch*len

    batch_size, decoder_len = tf.shape(self.responses)[0], tf.shape(
        self.responses)[1]
    # delete the last column of responses_target and add GO at the front
    self.responses_input = tf.concat([
        tf.ones([batch_size, 1], dtype=tf.int32) * GO_ID,
        tf.split(self.responses_target, [decoder_len - 1, 1], 1)[0]
    ], 1)  # batch*len
    # 1 inside the valid response length, 0 after it
    self.decoder_mask = tf.reshape(
        tf.cumsum(tf.one_hot(self.responses_length - 1, decoder_len),
                  reverse=True,
                  axis=1), [-1, decoder_len])  # batch * len

    print "embedding..."
    # build the embedding table (index to vector)
    if embed is None:
        # initialize the embedding randomly
        self.embed = tf.get_variable('embed',
                                     [num_symbols, num_embed_units],
                                     tf.float32)
    else:
        print len(vocab), len(embed), len(embed[0])
        print embed
        # initialize the embedding by pre-trained word vectors
        self.embed = tf.get_variable('embed',
                                     dtype=tf.float32,
                                     initializer=embed)

    self.encoder_input = tf.nn.embedding_lookup(
        self.embed, self.posts_input)  # batch*len*unit
    self.decoder_input = tf.nn.embedding_lookup(self.embed,
                                                self.responses_input)
    print "embedding finished"

    # NOTE(review): [Cell(...)] * num_layers repeats the SAME cell object
    # for every layer (shares weights across layers in newer TF releases)
    # — confirm this matches the TF version this module targets.
    if use_lstm:
        cell = MultiRNNCell([LSTMCell(num_units)] * num_layers)
    else:
        cell = MultiRNNCell([GRUCell(num_units)] * num_layers)

    # rnn encoder
    encoder_output, encoder_state = dynamic_rnn(cell,
                                                self.encoder_input,
                                                self.posts_length,
                                                dtype=tf.float32,
                                                scope="encoder")

    # get output projection function
    output_fn, sampled_sequence_loss = output_projection_layer(
        num_units, num_symbols, num_qwords, num_samples, question_data)

    print "encoder_output.shape:", encoder_output.get_shape()

    # get attention function
    attention_keys, attention_values, attention_score_fn, attention_construct_fn \
        = attention_decoder_fn.prepare_attention(encoder_output, 'luong', num_units)

    # get decoding loop functions (training and inference)
    decoder_fn_train = attention_decoder_fn.attention_decoder_fn_train(
        encoder_state, attention_keys, attention_values,
        attention_score_fn, attention_construct_fn)
    decoder_fn_inference = attention_decoder_fn.attention_decoder_fn_inference(
        output_fn, self.keyword_tensor, encoder_state, attention_keys,
        attention_values, attention_score_fn, attention_construct_fn,
        self.embed, GO_ID, EOS_ID, max_length, num_symbols)

    if is_train:
        # rnn decoder
        self.decoder_output, _, _ = dynamic_rnn_decoder(
            cell,
            decoder_fn_train,
            self.decoder_input,
            self.responses_length,
            scope="decoder")
        # calculate the loss of decoder
        # self.decoder_output = tf.Print(self.decoder_output, [self.decoder_output])
        self.decoder_loss, self.log_perplexity = sampled_sequence_loss(
            self.decoder_output, self.responses_target, self.decoder_mask,
            self.keyword_tensor, self.word_type)

        # building graph finished and get all parameters
        self.params = tf.trainable_variables()
        for item in tf.trainable_variables():
            print item.name, item.get_shape()

        # initialize the training process
        self.learning_rate = tf.Variable(float(learning_rate),
                                         trainable=False,
                                         dtype=tf.float32)
        self.learning_rate_decay_op = self.learning_rate.assign(
            self.learning_rate * learning_rate_decay_factor)
        self.global_step = tf.Variable(0, trainable=False)

        # calculate the gradient of parameters, with global-norm clipping
        opt = tf.train.GradientDescentOptimizer(self.learning_rate)
        gradients = tf.gradients(self.decoder_loss, self.params)
        clipped_gradients, self.gradient_norm = tf.clip_by_global_norm(
            gradients, max_gradient_norm)
        self.update = opt.apply_gradients(zip(clipped_gradients,
                                              self.params),
                                          global_step=self.global_step)
    else:
        # rnn decoder
        self.decoder_distribution, _, _ = dynamic_rnn_decoder(
            cell, decoder_fn_inference, scope="decoder")
        print("self.decoder_distribution.shape():",
              self.decoder_distribution.get_shape())
        self.decoder_distribution = tf.Print(
            self.decoder_distribution,
            ["distribution.shape()",
             tf.reduce_sum(self.decoder_distribution)])
        # generating the response: drop the first two symbol ids before the
        # argmax, then add 2 to restore absolute ids
        self.generation_index = tf.argmax(
            tf.split(self.decoder_distribution, [2, num_symbols - 2], 2)[1],
            2) + 2  # for removing UNK
        self.generation = tf.nn.embedding_lookup(self.symbols,
                                                 self.generation_index)
        self.params = tf.trainable_variables()

    self.saver = tf.train.Saver(tf.global_variables(),
                                write_version=tf.train.SaverDef.V2,
                                max_to_keep=3,
                                pad_step_number=True,
                                keep_checkpoint_every_n_hours=1.0)
def __init__(self, num_symbols, num_embed_units, num_units, is_train,
             vocab=None, content_pos=None, rhetoric_pos=None, embed=None,
             learning_rate=0.1, learning_rate_decay_factor=0.9995,
             max_gradient_norm=5.0, max_length=30, latent_size=128,
             use_lstm=False, num_classes=3, full_kl_step=80000,
             mem_slot_num=4, mem_size=128):
    """Build a CVAE-style sentence-rewriting graph with a topic memory.

    A shared bidirectional encoder encodes original and rewritten
    sentences; recognition and prior networks parameterize a Gaussian
    latent variable; a classifier predicts the pattern label from the
    latent sample; the decoder is conditioned on the latent sample, label
    embedding and topic memory. Training combines reconstruction loss, a
    KL term annealed over ``full_kl_step`` steps, and the classifier loss,
    optimized with Momentum.
    """
    # ----- placeholders ---------------------------------------------------
    self.ori_sents = tf.placeholder(tf.string, shape=(None, None))
    self.ori_sents_length = tf.placeholder(tf.int32, shape=(None))
    self.rep_sents = tf.placeholder(tf.string, shape=(None, None))
    self.rep_sents_length = tf.placeholder(tf.int32, shape=(None))
    self.labels = tf.placeholder(tf.float32, shape=(None, num_classes))
    # sample z from the prior (True) or the recognition network (False)
    self.use_prior = tf.placeholder(tf.bool)
    # external global step, drives KL annealing
    self.global_t = tf.placeholder(tf.int32)

    # vocabulary masks for content-word / rhetoric-word positions
    self.content_mask = tf.reduce_sum(tf.one_hot(content_pos, num_symbols,
                                                 1.0, 0.0),
                                      axis=0)
    self.rhetoric_mask = tf.reduce_sum(tf.one_hot(rhetoric_pos, num_symbols,
                                                  1.0, 0.0),
                                       axis=0)

    # NOTE(review): tf.zeros with a None dimension in `shape` is not a
    # valid static shape — this looks like it would fail at graph-build
    # time; confirm how the batch dimension is meant to be supplied.
    topic_memory = tf.zeros(name="topic_memory",
                            dtype=tf.float32,
                            shape=[None, mem_slot_num, mem_size])
    # NOTE(review): local variable, not referenced again in this method —
    # presumably consumed by update_memory via scope; confirm.
    w_topic_memory = tf.get_variable(
        name="w_topic_memory",
        dtype=tf.float32,
        initializer=tf.random_uniform([mem_size, mem_size], -0.1, 0.1))

    # build the vocab table (string to index)
    if is_train:
        self.symbols = tf.Variable(vocab, trainable=False, name="symbols")
    else:
        self.symbols = tf.Variable(np.array(['.'] * num_symbols),
                                   name="symbols")
    self.symbol2index = HashTable(
        KeyValueTensorInitializer(
            self.symbols,
            tf.Variable(
                np.array([i for i in range(num_symbols)], dtype=np.int32),
                False)),
        default_value=UNK_ID,
        name="symbol2index")

    self.ori_sents_input = self.symbol2index.lookup(self.ori_sents)
    self.rep_sents_target = self.symbol2index.lookup(self.rep_sents)
    batch_size, decoder_len = tf.shape(self.rep_sents)[0], tf.shape(
        self.rep_sents)[1]
    # drop the last target column and prepend GO_ID
    self.rep_sents_input = tf.concat([
        tf.ones([batch_size, 1], dtype=tf.int32) * GO_ID,
        tf.split(self.rep_sents_target, [decoder_len - 1, 1], 1)[0]
    ], 1)
    # 1 inside the valid target length, 0 after it
    self.decoder_mask = tf.reshape(
        tf.cumsum(tf.one_hot(self.rep_sents_length - 1, decoder_len),
                  reverse=True,
                  axis=1), [-1, decoder_len])

    # build the embedding table (index to vector)
    if embed is None:
        # initialize the embedding randomly
        self.embed = tf.get_variable('embed',
                                     [num_symbols, num_embed_units],
                                     tf.float32)
    else:
        # initialize the embedding by pre-trained word vectors
        self.embed = tf.get_variable('embed',
                                     dtype=tf.float32,
                                     initializer=embed)
    # one embedding row per pattern class
    self.pattern_embed = tf.get_variable('pattern_embed',
                                         [num_classes, num_embed_units],
                                         tf.float32)

    self.encoder_input = tf.nn.embedding_lookup(self.embed,
                                                self.ori_sents_input)
    self.decoder_input = tf.nn.embedding_lookup(self.embed,
                                                self.rep_sents_input)

    if use_lstm:
        cell_fw = LSTMCell(num_units)
        cell_bw = LSTMCell(num_units)
        cell_dec = LSTMCell(2 * num_units)
    else:
        cell_fw = GRUCell(num_units)
        cell_bw = GRUCell(num_units)
        cell_dec = GRUCell(2 * num_units)

    # origin sentence encoder
    with variable_scope.variable_scope("encoder"):
        encoder_output, encoder_state = tf.nn.bidirectional_dynamic_rnn(
            cell_fw,
            cell_bw,
            self.encoder_input,
            self.ori_sents_length,
            dtype=tf.float32)
        post_sum_state = tf.concat(encoder_state, 1)
        encoder_output = tf.concat(encoder_output, 2)

    # response sentence encoder (shares weights with the origin encoder)
    with variable_scope.variable_scope("encoder", reuse=True):
        decoder_state, decoder_last_state = tf.nn.bidirectional_dynamic_rnn(
            cell_fw,
            cell_bw,
            self.decoder_input,
            self.rep_sents_length,
            dtype=tf.float32)
        response_sum_state = tf.concat(decoder_last_state, 1)

    # recognition network: q(z | x, y)
    with variable_scope.variable_scope("recog_net"):
        recog_input = tf.concat([post_sum_state, response_sum_state], 1)
        recog_mulogvar = tf.contrib.layers.fully_connected(
            recog_input, latent_size * 2, activation_fn=None, scope="muvar")
        recog_mu, recog_logvar = tf.split(recog_mulogvar, 2, axis=1)

    # prior network: p(z | x)
    with variable_scope.variable_scope("prior_net"):
        prior_fc1 = tf.contrib.layers.fully_connected(post_sum_state,
                                                      latent_size * 2,
                                                      activation_fn=tf.tanh,
                                                      scope="fc1")
        prior_mulogvar = tf.contrib.layers.fully_connected(
            prior_fc1, latent_size * 2, activation_fn=None, scope="muvar")
        prior_mu, prior_logvar = tf.split(prior_mulogvar, 2, axis=1)

    # reparameterized sample of z from prior or recognition distribution
    latent_sample = tf.cond(
        self.use_prior, lambda: sample_gaussian(prior_mu, prior_logvar),
        lambda: sample_gaussian(recog_mu, recog_logvar))

    # classifier predicting the pattern label from the latent sample
    with variable_scope.variable_scope("classifier"):
        classifier_input = latent_sample
        pattern_fc1 = tf.contrib.layers.fully_connected(
            classifier_input,
            latent_size,
            activation_fn=tf.tanh,
            scope="pattern_fc1")
        self.pattern_logits = tf.contrib.layers.fully_connected(
            pattern_fc1,
            num_classes,
            activation_fn=None,
            scope="pattern_logits")

    # label embedding from the (one-hot / soft) labels
    self.label_embedding = tf.matmul(self.labels, self.pattern_embed)

    output_fn, my_sequence_loss = output_projection_layer(
        2 * num_units, num_symbols, latent_size, num_embed_units,
        self.content_mask, self.rhetoric_mask)

    attention_keys, attention_values, attention_score_fn, attention_construct_fn = my_attention_decoder_fn.prepare_attention(
        encoder_output, 'luong', 2 * num_units)

    # initial decoder state from [post state; label embedding; z]
    with variable_scope.variable_scope("dec_start"):
        temp_start = tf.concat(
            [post_sum_state, self.label_embedding, latent_sample], 1)
        dec_fc1 = tf.contrib.layers.fully_connected(temp_start,
                                                    2 * num_units,
                                                    activation_fn=tf.tanh,
                                                    scope="dec_start_fc1")
        dec_fc2 = tf.contrib.layers.fully_connected(dec_fc1,
                                                    2 * num_units,
                                                    activation_fn=None,
                                                    scope="dec_start_fc2")

    if is_train:
        # rnn decoder (training)
        topic_memory = self.update_memory(topic_memory, encoder_output)
        # NOTE(review): topic_memory is created rank-3 above while the
        # other concat operands are rank-2 — presumably update_memory
        # flattens it; confirm.
        extra_info = tf.concat(
            [self.label_embedding, latent_sample, topic_memory], 1)
        decoder_fn_train = my_attention_decoder_fn.attention_decoder_fn_train(
            dec_fc2, attention_keys, attention_values, attention_score_fn,
            attention_construct_fn, extra_info)
        self.decoder_output, _, _ = my_seq2seq.dynamic_rnn_decoder(
            cell_dec,
            decoder_fn_train,
            self.decoder_input,
            self.rep_sents_length,
            scope="decoder")
        # calculate the loss
        self.decoder_loss = my_loss.sequence_loss(
            logits=self.decoder_output,
            targets=self.rep_sents_target,
            weights=self.decoder_mask,
            extra_information=latent_sample,
            label_embedding=self.label_embedding,
            softmax_loss_function=my_sequence_loss)
        # KL divergence, annealed linearly from 0 to 1 over full_kl_step
        temp_klloss = tf.reduce_mean(
            gaussian_kld(recog_mu, recog_logvar, prior_mu, prior_logvar))
        self.kl_weight = tf.minimum(
            tf.to_float(self.global_t) / full_kl_step, 1.0)
        self.klloss = self.kl_weight * temp_klloss
        # classifier loss against the argmax label
        temp_labels = tf.argmax(self.labels, 1)
        self.classifierloss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=self.pattern_logits, labels=temp_labels))
        # need to anneal the kl_weight
        self.loss = self.decoder_loss + self.klloss + self.classifierloss

        # building graph finished and get all parameters
        self.params = tf.trainable_variables()

        # initialize the training process
        self.learning_rate = tf.Variable(float(learning_rate),
                                         trainable=False,
                                         dtype=tf.float32)
        self.learning_rate_decay_op = self.learning_rate.assign(
            self.learning_rate * learning_rate_decay_factor)
        self.global_step = tf.Variable(0, trainable=False)

        # calculate the gradient of parameters, with global-norm clipping
        opt = tf.train.MomentumOptimizer(self.learning_rate, 0.9)
        gradients = tf.gradients(self.loss, self.params)
        clipped_gradients, self.gradient_norm = tf.clip_by_global_norm(
            gradients, max_gradient_norm)
        self.update = opt.apply_gradients(zip(clipped_gradients,
                                              self.params),
                                          global_step=self.global_step)
    else:
        # rnn decoder (inference)
        topic_memory = self.update_memory(topic_memory, encoder_output)
        extra_info = tf.concat(
            [self.label_embedding, latent_sample, topic_memory], 1)
        decoder_fn_inference = my_attention_decoder_fn.attention_decoder_fn_inference(
            output_fn, dec_fc2, attention_keys, attention_values,
            attention_score_fn, attention_construct_fn, self.embed, GO_ID,
            EOS_ID, max_length, num_symbols, extra_info)
        self.decoder_distribution, _, _ = my_seq2seq.dynamic_rnn_decoder(
            cell_dec, decoder_fn_inference, scope="decoder")
        # drop the first two symbol ids before the argmax, then add 2 to
        # restore absolute ids
        self.generation_index = tf.argmax(
            tf.split(self.decoder_distribution, [2, num_symbols - 2], 2)[1],
            2) + 2  # for removing UNK
        self.generation = tf.nn.embedding_lookup(self.symbols,
                                                 self.generation_index)
        self.params = tf.trainable_variables()

    self.saver = tf.train.Saver(tf.global_variables(),
                                write_version=tf.train.SaverDef.V2,
                                max_to_keep=3,
                                pad_step_number=True,
                                keep_checkpoint_every_n_hours=1.0)
def __init__(self,
             num_items,
             num_embed_units,
             num_units,
             num_layers,
             embed=None,
             learning_rate=1e-4,
             action_num=10,
             learning_rate_decay_factor=0.95,
             max_gradient_norm=5.0,
             use_lstm=True):
    """Build the graph of a session-based recommendation agent.

    Encodes click sessions with an RNN, scores the per-step candidate
    lists, exposes top-k indices for acting and evaluation, a
    reward-weighted loss with Adam updates, and an inference path with an
    externally fed LSTM state plus a Gumbel-noise variant of the top-k
    (presumably for exploration).
    """
    self.epoch = tf.Variable(0, trainable=False, name='agn/epoch')
    self.epoch_add_op = self.epoch.assign(self.epoch + 1)

    # ----- placeholders ---------------------------------------------------
    self.sessions_input = tf.placeholder(tf.int32,
                                         shape=(None, None))  # [batch, len]
    self.rec_lists = tf.placeholder(
        tf.int32, shape=(None, None, None))  # [batch, len, rec_len] item ids
    self.rec_mask = tf.placeholder(
        tf.float32, shape=(None, None, None))  # [batch, len, rec_len] mask
    self.aims_idx = tf.placeholder(
        tf.int32, shape=(None, None))  # [batch, len] chosen candidate index
    self.sessions_length = tf.placeholder(tf.int32, shape=(None))  # [batch]
    self.reward = tf.placeholder(tf.float32,
                                 shape=(None))  # per-sample reward weight

    if embed is None:
        self.embed = tf.get_variable(
            'agn/embed', [num_items, num_embed_units],
            tf.float32,
            initializer=tf.truncated_normal_initializer(0, 1))
    else:
        self.embed = tf.get_variable('agn/embed',
                                     dtype=tf.float32,
                                     initializer=embed)

    batch_size, encoder_length, rec_length = tf.shape(
        self.sessions_input)[0], tf.shape(
            self.sessions_input)[1], tf.shape(self.rec_lists)[2]
    # [batch_size, length]: 1 on the first (sessions_length - 1) steps,
    # 0 afterwards (one-hot at length-2, reversed cumulative sum)
    encoder_mask = tf.reshape(
        tf.cumsum(tf.one_hot(self.sessions_length - 2, encoder_length),
                  reverse=True,
                  axis=1), [-1, encoder_length])
    # next-item targets: the session shifted left by one, padded with PAD_ID
    self.sessions_target = tf.concat([
        self.sessions_input[:, 1:],
        tf.ones([batch_size, 1], dtype=tf.int32) * PAD_ID
    ], 1)
    # [batch_size, length, embed_units]
    self.encoder_input = tf.nn.embedding_lookup(self.embed,
                                                self.sessions_input)
    # [batch_size, length, rec_length]
    self.aims = tf.one_hot(self.aims_idx, rec_length)

    if use_lstm:
        cell = MultiRNNCell(
            [LSTMCell(num_units) for _ in range(num_layers)])
    else:
        cell = MultiRNNCell(
            [GRUCell(num_units) for _ in range(num_layers)])

    # Training
    with tf.variable_scope("agn"):
        output_fn, sampled_sequence_loss = output_projection_layer(
            num_units, num_items)
        self.encoder_output, self.encoder_state = dynamic_rnn(
            cell,
            self.encoder_input,
            self.sessions_length,
            dtype=tf.float32,
            scope="encoder")
        # index helpers for gathering each candidate's probability
        tmp_dim_1 = tf.tile(
            tf.reshape(tf.range(batch_size), [batch_size, 1, 1, 1]),
            [1, encoder_length, rec_length, 1])
        tmp_dim_2 = tf.tile(
            tf.reshape(tf.range(encoder_length), [1, encoder_length, 1, 1]),
            [batch_size, 1, rec_length, 1])
        # [batch_size, length, rec_length, 3]: (batch, step, item-id)
        gather_idx = tf.concat(
            [tmp_dim_1, tmp_dim_2,
             tf.expand_dims(self.rec_lists, 3)], 3)
        # [batch_size, length, num_items], [batch_size*length]
        y_prob, local_loss, total_size = sampled_sequence_loss(
            self.encoder_output, self.sessions_target, encoder_mask)

        # Compute recommendation rank given rec_list
        # zero out the first two item ids before ranking
        # [batch_size, length, num_items]
        y_prob = tf.reshape(y_prob, [batch_size, encoder_length, num_items]) * \
            tf.concat([tf.zeros([batch_size, encoder_length, 2], dtype=tf.float32),
                       tf.ones([batch_size, encoder_length, num_items - 2], dtype=tf.float32)], 2)
        # [batch_size, length, rec_len]: probability of each candidate
        ini_prob = tf.reshape(tf.gather_nd(y_prob, gather_idx),
                              [batch_size, encoder_length, rec_length])
        # [batch_size, length, rec_len]: mask out padded candidates
        mul_prob = ini_prob * self.rec_mask

        # [batch_size, length, action_num]
        _, self.index = tf.nn.top_k(mul_prob, k=action_num)
        # [batch_size, length, metric_num]
        _, self.metric_index = tf.nn.top_k(mul_prob,
                                           k=(FLAGS['metric'].value + 1))
        # reward-weighted mean loss
        self.loss = tf.reduce_sum(
            tf.reshape(self.reward, [-1]) * local_loss) / total_size

    # Inference
    with tf.variable_scope("agn", reuse=True):
        # tf.get_variable_scope().reuse_variables()
        # externally fed state, shaped [2 layers, (c, h), batch, units].
        # NOTE(review): this hard-codes a 2-layer LSTM state; it will not
        # match when use_lstm=False or num_layers != 2 — confirm.
        self.lstm_state = tf.placeholder(tf.float32,
                                         shape=(2, 2, None, num_units))
        self.ini_state = (tf.contrib.rnn.LSTMStateTuple(
            self.lstm_state[0, 0, :, :], self.lstm_state[0, 1, :, :]),
                          tf.contrib.rnn.LSTMStateTuple(
                              self.lstm_state[1, 0, :, :],
                              self.lstm_state[1, 1, :, :]))
        # [batch_size, length, num_units]
        self.encoder_output_predict, self.encoder_state_predict = dynamic_rnn(
            cell,
            self.encoder_input,
            self.sessions_length,
            initial_state=self.ini_state,
            dtype=tf.float32,
            scope="encoder")
        # [batch_size, num_units]: output of the last step
        self.final_output_predict = tf.reshape(
            self.encoder_output_predict[:, -1, :], [-1, num_units])
        # [batch_size, num_items]
        self.rec_logits = output_fn(self.final_output_predict)
        # [batch_size, action_num], skipping the special vocab ids
        _, self.rec_index = tf.nn.top_k(
            self.rec_logits[:, len(_START_VOCAB):], action_num)
        self.rec_index += len(_START_VOCAB)

        def gumbel_max(inp, alpha, beta):
            """Perturb logits with Gumbel noise; alpha/beta scale noise and sharpness."""
            # assert len(tf.shape(inp)) == 2
            g = tf.random_uniform(tf.shape(inp), 0.0001, 0.9999)
            g = -tf.log(-tf.log(g))  # Gumbel(0, 1) noise
            inp_g = tf.nn.softmax(
                (tf.nn.log_softmax(inp / 1.0) + g * alpha) * beta)
            return inp_g

        # [batch_size, action_num]: noisy top-k over the same logits
        _, self.random_rec_index = tf.nn.top_k(
            gumbel_max(self.rec_logits[:, len(_START_VOCAB):], 1, 1),
            action_num)
        self.random_rec_index += len(_START_VOCAB)

    # initialize the training process
    self.learning_rate = tf.Variable(float(learning_rate),
                                     trainable=False,
                                     dtype=tf.float32)
    self.learning_rate_decay_op = self.learning_rate.assign(
        self.learning_rate * learning_rate_decay_factor)
    self.global_step = tf.Variable(0, trainable=False)

    self.params = tf.trainable_variables()
    gradients = tf.gradients(self.loss, self.params)
    clipped_gradients, self.gradient_norm = tf.clip_by_global_norm(
        gradients, max_gradient_norm)
    self.update = tf.train.AdamOptimizer(
        self.learning_rate).apply_gradients(
            zip(clipped_gradients, self.params),
            global_step=self.global_step)

    self.saver = tf.train.Saver(tf.global_variables(),
                                write_version=tf.train.SaverDef.V2,
                                max_to_keep=100,
                                pad_step_number=True,
                                keep_checkpoint_every_n_hours=1.0)
def __init__(
        self,
        num_symbols,  # vocabulary size
        num_embed_units,  # word-embedding size
        num_units,  # units per RNN layer
        num_layers,  # number of RNN layers
        embed,  # pre-trained word-embedding matrix (or None)
        entity_embed=None,  # pre-trained entity+relation embeddings (or None)
        num_entities=0,  # total number of entities + relations
        num_trans_units=100,  # entity (trans) embedding dimension
        memory_units=100,
        learning_rate=0.0001,  # learning rate
        learning_rate_decay_factor=0.95,  # decay factor; this decay scheme is not actually applied
        max_gradient_norm=5.0,  # global-norm gradient clipping threshold
        num_samples=500,  # number of samples for sampled softmax
        max_length=60,
        mem_use=True,
        output_alignments=True,
        use_lstm=False):
    """Build the knowledge-grounded seq2seq graph with post/response memory networks.

    Constructs the input placeholders, vocabulary/entity hash tables, word and
    entity embeddings, static graph attention over knowledge triples, a
    post-side and a response-side memory network whose combined state
    initializes the decoder, plus training and inference decoders and the
    Adam training ops.
    """
    self.posts = tf.placeholder(tf.string, (None, None),
                                'enc_inps')  # [batch_size, encoder_len]
    self.posts_length = tf.placeholder(tf.int32, (None),
                                       'enc_lens')  # [batch_size]
    self.responses = tf.placeholder(
        tf.string, (None, None), 'dec_inps')  # [batch_size, decoder_len]
    self.responses_length = tf.placeholder(tf.int32, (None),
                                           'dec_lens')  # [batch_size]
    self.entities = tf.placeholder(
        tf.string, (None, None, None),
        'entities')  # [batch_size, triple_num, triple_len]
    self.entity_masks = tf.placeholder(tf.string, (None, None),
                                       'entity_masks')  # unused in this graph
    self.triples = tf.placeholder(
        tf.string, (None, None, None, 3),
        'triples')  # [batch_size, triple_num, triple_len, 3]
    self.posts_triple = tf.placeholder(
        tf.int32, (None, None, 1),
        'enc_triples')  # [batch_size, encoder_len, 1]
    self.responses_triple = tf.placeholder(
        tf.string, (None, None, 3),
        'dec_triples')  # [batch_size, decoder_len, 3]
    self.match_triples = tf.placeholder(
        tf.int32, (None, None, None),
        'match_triples')  # [batch_size, decoder_len, triple_num]

    # encoder batch size and encoder sequence length
    encoder_batch_size, encoder_len = tf.unstack(tf.shape(self.posts))
    triple_num = tf.shape(self.triples)[1]  # number of knowledge graphs per post
    triple_len = tf.shape(self.triples)[2]  # number of triples per graph

    # one-hot mask of which knowledge triple each decoding step matched
    one_hot_triples = tf.one_hot(
        self.match_triples,
        triple_len)  # [batch_size, decoder_len, triple_num, triple_len]
    # 1 at the decoding steps whose output word came from a knowledge triple
    use_triples = tf.reduce_sum(one_hot_triples,
                                axis=[2, 3])  # [batch_size, decoder_len]

    # hash table mapping word string -> index
    self.symbol2index = MutableHashTable(
        key_dtype=tf.string,  # dtype of the key tensor
        value_dtype=tf.int64,  # dtype of the value tensor
        default_value=UNK_ID,  # default value for missing keys
        shared_name=
        "in_table",  # If non-empty, this table will be shared under the given name across multiple sessions
        name="in_table",  # op name
        checkpoint=True
    )  # if True, the contents of the table are saved to and restored from checkpoints. If shared_name is empty for a checkpointed table, it is shared using the table node name.
    # hash table mapping index -> word string
    self.index2symbol = MutableHashTable(key_dtype=tf.int64,
                                         value_dtype=tf.string,
                                         default_value='_UNK',
                                         shared_name="out_table",
                                         name="out_table",
                                         checkpoint=True)
    # hash table mapping entity string -> index
    self.entity2index = MutableHashTable(key_dtype=tf.string,
                                         value_dtype=tf.int64,
                                         default_value=NONE_ID,
                                         shared_name="entity_in_table",
                                         name="entity_in_table",
                                         checkpoint=True)
    # hash table mapping index -> entity string
    self.index2entity = MutableHashTable(key_dtype=tf.int64,
                                         value_dtype=tf.string,
                                         default_value='_NONE',
                                         shared_name="entity_out_table",
                                         name="entity_out_table",
                                         checkpoint=True)

    self.posts_word_id = self.symbol2index.lookup(
        self.posts)  # [batch_size, encoder_len]
    self.posts_entity_id = self.entity2index.lookup(
        self.posts)  # [batch_size, encoder_len]
    self.responses_target = self.symbol2index.lookup(
        self.responses)  # [batch_size, decoder_len]

    # decoder batch size and decoder sequence length
    batch_size, decoder_len = tf.shape(self.responses)[0], tf.shape(
        self.responses)[1]
    # drop the last column of responses_target and prepend GO_ID
    self.responses_word_id = tf.concat([
        tf.ones([batch_size, 1], dtype=tf.int64) * GO_ID,
        tf.split(self.responses_target, [decoder_len - 1, 1], 1)[0]
    ], 1)  # [batch_size, decoder_len]
    # sequence-length mask over the responses
    self.decoder_mask = tf.reshape(
        tf.cumsum(tf.one_hot(self.responses_length - 1, decoder_len),
                  reverse=True,
                  axis=1), [-1, decoder_len])  # [batch_size, decoder_len]

    # word/entity embeddings: use the given values if provided, else random init
    if embed is None:
        self.embed = tf.get_variable('word_embed',
                                     [num_symbols, num_embed_units],
                                     tf.float32)
    else:
        self.embed = tf.get_variable('word_embed',
                                     dtype=tf.float32,
                                     initializer=embed)
    if entity_embed is None:
        # entity embeddings are frozen (not updated during training)
        self.entity_trans = tf.get_variable(
            'entity_embed', [num_entities, num_trans_units],
            tf.float32,
            trainable=False)
    else:
        self.entity_trans = tf.get_variable('entity_embed',
                                            dtype=tf.float32,
                                            initializer=entity_embed,
                                            trainable=False)

    # pass the (frozen) entity embeddings through a trainable dense layer
    self.entity_trans_transformed = tf.layers.dense(
        self.entity_trans,
        num_trans_units,
        activation=tf.tanh,
        name='trans_transformation')
    # embeddings for the 7 special tokens
    # ['_NONE', '_PAD_H', '_PAD_R', '_PAD_T', '_NAF_H', '_NAF_R', '_NAF_T']
    padding_entity = tf.get_variable('entity_padding_embed',
                                     [7, num_trans_units],
                                     dtype=tf.float32,
                                     initializer=tf.zeros_initializer())
    self.entity_embed = tf.concat(
        [padding_entity, self.entity_trans_transformed], axis=0)

    # triples_embedding: [batch_size, triple_num, triple_len, 3*num_trans_units]
    # embeddings of the knowledge-graph triples
    triples_embedding = tf.reshape(
        tf.nn.embedding_lookup(self.entity_embed,
                               self.entity2index.lookup(self.triples)),
        [encoder_batch_size, triple_num, -1, 3 * num_trans_units])
    # entities_word_embedding: [batch_size, triple_num*triple_len, num_embed_units]
    # word embeddings of every entity appearing in the graphs
    entities_word_embedding = tf.reshape(
        tf.nn.embedding_lookup(self.embed,
                               self.symbol2index.lookup(self.entities)),
        [encoder_batch_size, -1, num_embed_units])

    # split triples into head/relation/tail:
    # each [batch_size, triple_num, triple_len, num_trans_units]
    head, relation, tail = tf.split(triples_embedding,
                                    [num_trans_units] * 3,
                                    axis=3)

    # static graph attention
    with tf.variable_scope('graph_attention'):
        # concat head and tail [batch_size, triple_num, triple_len, 2*num_trans_units]
        head_tail = tf.concat([head, tail], axis=3)
        # dense layer over head_tail [batch_size, triple_num, triple_len, num_trans_units]
        head_tail_transformed = tf.layers.dense(head_tail,
                                                num_trans_units,
                                                activation=tf.tanh,
                                                name='head_tail_transform')
        # dense layer over relation [batch_size, triple_num, triple_len, num_trans_units]
        relation_transformed = tf.layers.dense(relation,
                                               num_trans_units,
                                               name='relation_transform')
        # inner product of the two vectors -> attention logits over the triples
        e_weight = tf.reduce_sum(
            relation_transformed * head_tail_transformed,
            axis=3)  # [batch_size, triple_num, triple_len]
        alpha_weight = tf.nn.softmax(
            e_weight)  # [batch_size, triple_num, triple_len]
        # tf.expand_dims gives alpha_weight an extra trailing axis
        # [batch_size, triple_num, triple_len, 1]; summing over axis 2 yields
        # the static graph representation
        graph_embed = tf.reduce_sum(
            tf.expand_dims(alpha_weight, 3) * head_tail,
            axis=2)  # [batch_size, triple_num, 2*num_trans_units]

    """graph_embed_input
    1、首先一维的range列表[0, 1, 2... encoder_batch_size个]转化成三维的[encoder_batch_size, 1, 1]的矩阵
    [[[0]], [[1]], [[2]],...]
    2、然后tf.tile将矩阵的第1维复制encoder_len遍,变成[encoder_batch_size, encoder_len, 1]
    [[[0],[0]...]],...]
    3、与posts_triple: [batch_size, encoder_len, 1]在第2维上进行拼接,形成一个indices: [batch_size, encoder_len, 2]矩阵,
    indices矩阵:
    [
    [[0 0], [0 0], [0 0], [0 0], [0 1], [0 0], [0 2], [0 0],...encoder_len],
    [[1 0], [1 0], [1 0], [1 0], [1 1], [1 0], [1 2], [1 0],...encoder_len],
    [[2 0], [2 0], [2 0], [2 0], [2 1], [2 0], [2 2], [2 0],...encoder_len]
    ,...batch_size
    ]
    4、tf.gather_nd根据索引检索graph_embed: [batch_size, triple_num, 2*num_trans_units]再回填至indices矩阵
    indices矩阵最后一个维度是2,例如有[0, 2],表示这个时间步第1个batch用了第2个图,
    则找到这个知识图的静态图向量填入到indices矩阵的[0, 2]位置最后得到结果维度
    [encoder_batch_size, encoder_len, 2*num_trans_units]表示每个时间步用的静态图向量
    """
    # (disabled in this memory-network variant) per-step static graph vectors
    # gathered according to posts_triple:
    # graph_embed_input = tf.gather_nd(graph_embed, tf.concat(
    #     [tf.tile(tf.reshape(tf.range(encoder_batch_size, dtype=tf.int32), [-1, 1, 1]), [1, encoder_len, 1]),
    #      self.posts_triple],
    #     axis=2))

    # (disabled) responses_triple mapped to triple embeddings
    # [batch_size, decoder_len, 3*num_trans_units] — which triple each response step used:
    # triple_embed_input = tf.reshape(
    #     tf.nn.embedding_lookup(self.entity_embed, self.entity2index.lookup(self.responses_triple)),
    #     [batch_size, decoder_len, 3 * num_trans_units])

    post_word_input = tf.nn.embedding_lookup(
        self.embed,
        self.posts_word_id)  # [batch_size, encoder_len, num_embed_units]
    response_word_input = tf.nn.embedding_lookup(
        self.embed, self.responses_word_id
    )  # [batch_size, decoder_len, num_embed_units]

    # (disabled) encoder input = post_word_input + graph_embed_input concatenated
    # [batch_size, encoder_len, num_embed_units+2*num_trans_units]
    # self.encoder_input = tf.concat([post_word_input, graph_embed_input], axis=2)
    # (disabled) decoder input = response_word_input + triple_embed_input concatenated
    # [batch_size, decoder_len, num_embed_units+3*num_trans_units]
    # self.decoder_input = tf.concat([response_word_input, triple_embed_input], axis=2)

    encoder_cell = MultiRNNCell(
        [GRUCell(num_units) for _ in range(num_layers)])
    decoder_cell = MultiRNNCell(
        [GRUCell(num_units) for _ in range(num_layers)])

    # rnn encoder over the post words only
    # encoder_state: per-layer GRU states — tuple of [batch_size, num_units]
    encoder_output, encoder_state = tf.nn.dynamic_rnn(encoder_cell,
                                                      post_word_input,
                                                      self.posts_length,
                                                      dtype=tf.float32,
                                                      scope="encoder")
    # self.encoder_state_shape = tf.shape(encoder_state)

    ######## memory network ###
    response_encoder_cell = MultiRNNCell(
        [GRUCell(num_units) for _ in range(num_layers)])
    response_encoder_output, response_encoder_state = tf.nn.dynamic_rnn(
        response_encoder_cell,
        response_word_input,
        self.responses_length,
        dtype=tf.float32,
        scope="response_encoder")

    # graph_embed: [batch_size, triple_num, 2*num_trans_units] static graph vectors
    # encoder_state: [num_layers, batch_size, num_units]
    with tf.variable_scope("post_memory_network"):
        # input memory vectors m from the static graphs
        post_input = tf.layers.dense(graph_embed,
                                     memory_units,
                                     use_bias=False,
                                     name="post_weight_a")
        post_input = tf.tile(
            tf.reshape(post_input,
                       (1, encoder_batch_size, triple_num, memory_units)),
            multiples=(
                num_layers, 1, 1,
                1))  # [num_layers, batch_size, triple_num, memory_units]
        # output memory vectors c from the static graphs
        post_output = tf.layers.dense(graph_embed,
                                      memory_units,
                                      use_bias=False,
                                      name="post_weight_c")
        post_output = tf.tile(
            tf.reshape(post_output,
                       (1, encoder_batch_size, triple_num, memory_units)),
            multiples=(
                num_layers, 1, 1,
                1))  # [num_layers, batch_size, triple_num, memory_units]
        # query state u from the post-encoder final states
        encoder_hidden_state = tf.reshape(
            tf.concat(encoder_state, axis=0),
            (num_layers, encoder_batch_size, num_units))
        post_state = tf.layers.dense(encoder_hidden_state,
                                     memory_units,
                                     use_bias=False,
                                     name="post_weight_b")
        post_state = tf.tile(
            tf.reshape(post_state,
                       (num_layers, encoder_batch_size, 1, memory_units)),
            multiples=(
                1, 1, triple_num,
                1))  # [num_layers, batch_size, triple_num, memory_units]
        # attention probability p over the graphs
        post_p = tf.reshape(
            tf.nn.softmax(tf.reduce_sum(post_state * post_input, axis=3)),
            (num_layers, encoder_batch_size, triple_num,
             1))  # [num_layers, batch_size, triple_num, 1]
        # memory output o
        post_o = tf.reduce_sum(
            post_output * post_p,
            axis=2)  # [num_layers, batch_size, memory_units]
        # NOTE(review): encoder_state is a tuple of per-layer states; tf.concat's
        # list argument implicitly converts it to a stacked
        # [num_layers, batch_size, num_units] tensor — confirm intended.
        post_xstar = tf.concat(
            [
                tf.layers.dense(post_o,
                                memory_units,
                                use_bias=False,
                                name="post_weight_r"), encoder_state
            ],
            axis=2)  # [num_layers, batch_size, num_units+memory_units]

    with tf.variable_scope("response_memory_network"):
        # input memory vectors m from the static graphs
        response_input = tf.layers.dense(graph_embed,
                                         memory_units,
                                         use_bias=False,
                                         name="response_weight_a")
        response_input = tf.tile(
            tf.reshape(response_input,
                       (1, batch_size, triple_num, memory_units)),
            multiples=(
                num_layers, 1, 1,
                1))  # [num_layers, batch_size, triple_num, memory_units]
        # output memory vectors c from the static graphs
        response_output = tf.layers.dense(graph_embed,
                                          memory_units,
                                          use_bias=False,
                                          name="response_weight_c")
        response_output = tf.tile(
            tf.reshape(response_output,
                       (1, batch_size, triple_num, memory_units)),
            multiples=(
                num_layers, 1, 1,
                1))  # [num_layers, batch_size, triple_num, memory_units]
        # query state u from the response-encoder final states
        response_hidden_state = tf.reshape(
            tf.concat(response_encoder_state, axis=0),
            (num_layers, batch_size, num_units))
        response_state = tf.layers.dense(response_hidden_state,
                                         memory_units,
                                         use_bias=False,
                                         name="response_weight_b")
        response_state = tf.tile(
            tf.reshape(response_state,
                       (num_layers, batch_size, 1, memory_units)),
            multiples=(
                1, 1, triple_num,
                1))  # [num_layers, batch_size, triple_num, memory_units]
        # attention probability p over the graphs
        response_p = tf.reshape(
            tf.nn.softmax(
                tf.reduce_sum(response_state * response_input, axis=3)),
            (num_layers, batch_size, triple_num,
             1))  # [num_layers, batch_size, triple_num, 1]
        # memory output o
        response_o = tf.reduce_sum(
            response_output * response_p,
            axis=2)  # [num_layers, batch_size, memory_units]
        response_ystar = tf.concat(
            [
                tf.layers.dense(response_o,
                                memory_units,
                                use_bias=False,
                                name="response_weight_r"),
                response_encoder_state
            ],
            axis=2)  # [num_layers, batch_size, num_units+memory_units]

    with tf.variable_scope("memory_network"):
        # fuse post and response memory outputs into the decoder's initial state
        memory_hidden_state = tf.layers.dense(tf.concat(
            [post_xstar, response_ystar], axis=2),
                                              num_units,
                                              use_bias=False,
                                              activation=tf.tanh,
                                              name="output_weight")
        # flatten to [num_layers*batch_size, num_units] ...
        memory_hidden_state = tf.reshape(
            memory_hidden_state, (num_layers * batch_size, num_units))
        # ... then split back into a per-layer tuple of [batch_size, num_units]
        memory_hidden_state = tuple(
            tf.split(memory_hidden_state, [batch_size] * num_layers,
                     axis=0))
        # self.memory_hidden_state_shape = tf.shape(memory_hidden_state)
    ######## ###

    output_fn, selector_fn, sequence_loss, sampled_sequence_loss, total_loss =\
        output_projection_layer(num_units, num_symbols, num_samples)

    ######## decoder used for training ###
    with tf.variable_scope('decoder'):
        attention_keys_init, attention_values_init, attention_score_fn_init, attention_construct_fn_init \
            = prepare_attention(encoder_output, 'bahdanau', num_units,
                                imem=(graph_embed, triples_embedding),
                                output_alignments=output_alignments and mem_use)
        # function handling each step's output and the next step's input
        # during training; initial state comes from the memory network
        decoder_fn_train = attention_decoder_fn_train(
            memory_hidden_state,
            attention_keys_init,
            attention_values_init,
            attention_score_fn_init,
            attention_construct_fn_init,
            output_alignments=output_alignments and mem_use,
            max_length=tf.reduce_max(self.responses_length))
        self.decoder_output, _, alignments_ta = dynamic_rnn_decoder(
            decoder_cell,
            decoder_fn_train,
            response_word_input,
            self.responses_length,
            scope="decoder_rnn")
        if output_alignments:
            self.alignments = tf.transpose(alignments_ta.stack(),
                                           perm=[1, 0, 2, 3])
            self.decoder_loss, self.ppx_loss, self.sentence_ppx = total_loss(
                self.decoder_output, self.responses_target,
                self.decoder_mask, self.alignments, triples_embedding,
                use_triples, one_hot_triples)
            self.sentence_ppx = tf.identity(self.sentence_ppx,
                                            name='ppx_loss')
        else:
            self.decoder_loss = sequence_loss(self.decoder_output,
                                              self.responses_target,
                                              self.decoder_mask)
    ######## ###

    ######## decoder used for inference ###
    with tf.variable_scope('decoder', reuse=True):
        attention_keys, attention_values, attention_score_fn, attention_construct_fn \
            = prepare_attention(encoder_output, 'bahdanau', num_units, reuse=True,
                                imem=(graph_embed, triples_embedding),
                                output_alignments=output_alignments and mem_use)
        # imem: tuple of (entity word embeddings
        #                 [batch_size, triple_num*triple_len, num_embed_units],
        #                 triple embeddings
        #                 [encoder_batch_size, triple_num*triple_len, 3*num_trans_units])
        decoder_fn_inference = \
            attention_decoder_fn_inference(output_fn,
                                           memory_hidden_state,
                                           attention_keys,
                                           attention_values,
                                           attention_score_fn,
                                           attention_construct_fn,
                                           self.embed,
                                           GO_ID,
                                           EOS_ID,
                                           max_length,
                                           num_symbols,
                                           imem=(entities_word_embedding,
                                                 tf.reshape(triples_embedding,
                                                            [encoder_batch_size, -1, 3*num_trans_units])),
                                           selector_fn=selector_fn)
        # decoder_distribution: [batch_size, decoder_len, num_symbols]
        # output_ids_ta: TensorArray of decoder_len entries, each [batch_size]
        self.decoder_distribution, _, output_ids_ta = dynamic_rnn_decoder(
            decoder_cell, decoder_fn_inference, scope="decoder_rnn")

        output_len = tf.shape(self.decoder_distribution)[1]  # decoder_len
        output_ids = tf.transpose(
            output_ids_ta.gather(
                tf.range(output_len)))  # [batch_size, decoder_len]
        # clip ids to the vocabulary range: negative ids mark entity words
        word_ids = tf.cast(tf.clip_by_value(output_ids, 0, num_symbols),
                           tf.int64)  # [batch_size, decoder_len]
        # flat positions of the entity words inside `entities`:
        # 1. tf.shape(entities_word_embedding)[1] = triple_num*triple_len
        # 2. tf.range(encoder_batch_size): [batch_size]
        # 3. reshape(range * triple_num*triple_len, [-1, 1]): per-example base offset
        # 4. tf.clip_by_value(-output_ids, 0, num_symbols): per-step entity offset
        # 5. entity_ids: flat index of each entity word in `entities`
        entity_ids = tf.reshape(
            tf.clip_by_value(-output_ids, 0, num_symbols) + tf.reshape(
                tf.range(encoder_batch_size) *
                tf.shape(entities_word_embedding)[1], [-1, 1]), [-1])
        # entity word strings actually used [batch_size, decoder_len]:
        # 1. entities: [batch_size, triple_num, triple_len]
        # 2. reshape(self.entities, [-1]): [batch_size*triple_num*triple_len]
        # 3. tf.gather: [batch_size*decoder_len]
        # 4. entities: [batch_size, decoder_len]
        entities = tf.reshape(
            tf.gather(tf.reshape(self.entities, [-1]), entity_ids),
            [-1, output_len])
        words = self.index2symbol.lookup(word_ids)  # ids -> word strings
        # where output_ids > 0 take the vocabulary word, else the entity word
        self.generation = tf.where(output_ids > 0, words, entities)
        self.generation = tf.identity(
            self.generation,
            name='generation')  # [batch_size, decoder_len]
    ######## ###

    # training setup
    self.learning_rate = tf.Variable(float(learning_rate),
                                     trainable=False,
                                     dtype=tf.float32)
    # learning-rate decay op (not actually invoked by this scheme)
    self.learning_rate_decay_op = self.learning_rate.assign(
        self.learning_rate * learning_rate_decay_factor)
    # number of parameter updates performed
    self.global_step = tf.Variable(0, trainable=False)
    # parameters to differentiate against
    # NOTE(review): tf.global_variables() also includes non-trainable
    # variables (e.g. the frozen entity_embed), whose gradients will be
    # None — confirm tf.trainable_variables() was not intended.
    self.params = tf.global_variables()
    # Adam optimizer
    opt = tf.train.AdamOptimizer(learning_rate=learning_rate)
    # NOTE(review): _lr is a private attribute of AdamOptimizer
    self.lr = opt._lr
    # gradients of decoder_loss w.r.t. params
    gradients = tf.gradients(self.decoder_loss, self.params)
    # gradient clipping by global norm
    clipped_gradients, self.gradient_norm = tf.clip_by_global_norm(
        gradients, max_gradient_norm)
    self.update = opt.apply_gradients(zip(clipped_gradients, self.params),
                                      global_step=self.global_step)
    # summaries for loss and variable histograms
    tf.summary.scalar('decoder_loss', self.decoder_loss)
    for each in tf.trainable_variables():
        tf.summary.histogram(each.name, each)
    self.merged_summary_op = tf.summary.merge_all()

    self.saver = tf.train.Saver(write_version=tf.train.SaverDef.V2,
                                max_to_keep=3,
                                pad_step_number=True,
                                keep_checkpoint_every_n_hours=1.0)
    self.saver_epoch = tf.train.Saver(write_version=tf.train.SaverDef.V2,
                                      max_to_keep=1000,
                                      pad_step_number=True)
def __init__(self,
             num_symbols,
             num_embed_units,
             num_units,
             num_layers,
             embed,
             entity_embed=None,
             num_entities=0,
             num_trans_units=100,
             learning_rate=0.0001,
             learning_rate_decay_factor=0.95,
             max_gradient_norm=5.0,
             num_samples=500,
             max_length=60,
             mem_use=True,
             output_alignments=True,
             use_lstm=False):
    """Build the knowledge-grounded seq2seq graph (no memory network).

    The encoder input concatenates post word embeddings with per-step static
    graph vectors; the decoder input concatenates response word embeddings
    with the matched triple embeddings. The encoder final state initializes
    both the training and the inference decoder.
    """
    # placeholders for the input data
    self.posts = tf.placeholder(tf.string, (None, None),
                                'enc_inps')  # batch*len
    self.posts_length = tf.placeholder(tf.int32, (None),
                                       'enc_lens')  # batch
    self.responses = tf.placeholder(tf.string, (None, None),
                                    'dec_inps')  # batch*len
    self.responses_length = tf.placeholder(tf.int32, (None),
                                           'dec_lens')  # batch
    self.entities = tf.placeholder(
        tf.string, (None, None, None),
        'entities')  # batch * triple_num * triple_len
    self.entity_masks = tf.placeholder(tf.string, (None, None),
                                       'entity_masks')  # batch
    self.triples = tf.placeholder(
        tf.string, (None, None, None, 3),
        'triples')  # batch * triple_num * triple_len * 3
    self.posts_triple = tf.placeholder(tf.int32, (None, None, 1),
                                       'enc_triples')  # batch * len * 1
    self.responses_triple = tf.placeholder(tf.string, (None, None, 3),
                                           'dec_triples')  # batch * len * 3
    self.match_triples = tf.placeholder(
        tf.int32, (None, None, None),
        'match_triples')  # batch * decoder_len * triple_num

    encoder_batch_size, encoder_len = tf.unstack(tf.shape(self.posts))
    triple_num = tf.shape(self.triples)[1]  # knowledge graphs per post
    triple_len = tf.shape(self.triples)[2]  # triples per graph
    # one-hot mask of the triple matched at each decoding step
    one_hot_triples = tf.one_hot(self.match_triples, triple_len)
    # 1 at steps whose output word used a knowledge triple
    use_triples = tf.reduce_sum(one_hot_triples, axis=[2, 3])

    # build the vocabulary lookup tables (string to index, index to string)
    self.symbol2index = MutableHashTable(key_dtype=tf.string,
                                         value_dtype=tf.int64,
                                         default_value=UNK_ID,
                                         shared_name="in_table",
                                         name="in_table",
                                         checkpoint=True)
    self.index2symbol = MutableHashTable(key_dtype=tf.int64,
                                         value_dtype=tf.string,
                                         default_value='_UNK',
                                         shared_name="out_table",
                                         name="out_table",
                                         checkpoint=True)
    self.entity2index = MutableHashTable(key_dtype=tf.string,
                                         value_dtype=tf.int64,
                                         default_value=NONE_ID,
                                         shared_name="entity_in_table",
                                         name="entity_in_table",
                                         checkpoint=True)
    self.index2entity = MutableHashTable(key_dtype=tf.int64,
                                         value_dtype=tf.string,
                                         default_value='_NONE',
                                         shared_name="entity_out_table",
                                         name="entity_out_table",
                                         checkpoint=True)

    self.posts_word_id = self.symbol2index.lookup(self.posts)  # batch*len
    self.posts_entity_id = self.entity2index.lookup(
        self.posts)  # batch*len
    #self.posts_word_id = tf.Print(self.posts_word_id, ['use_triples', use_triples, 'one_hot_triples', one_hot_triples], summarize=1e6)
    self.responses_target = self.symbol2index.lookup(
        self.responses)  # batch*len

    batch_size, decoder_len = tf.shape(self.responses)[0], tf.shape(
        self.responses)[1]
    # drop the last column of responses_target and prepend GO_ID
    self.responses_word_id = tf.concat([
        tf.ones([batch_size, 1], dtype=tf.int64) * GO_ID,
        tf.split(self.responses_target, [decoder_len - 1, 1], 1)[0]
    ], 1)  # batch*len
    # sequence-length mask over the responses
    self.decoder_mask = tf.reshape(
        tf.cumsum(tf.one_hot(self.responses_length - 1, decoder_len),
                  reverse=True,
                  axis=1), [-1, decoder_len])

    # build the word-embedding table (index to vector)
    if embed is None:
        # randomly initialize the word embeddings
        self.embed = tf.get_variable('word_embed',
                                     [num_symbols, num_embed_units],
                                     tf.float32)
    else:
        # initialize from pre-trained word vectors (GloVe or Word2Vec)
        self.embed = tf.get_variable('word_embed',
                                     dtype=tf.float32,
                                     initializer=embed)
    if entity_embed is None:
        # randomly initialize the entity embeddings (frozen during training)
        self.entity_trans = tf.get_variable(
            'entity_embed', [num_entities, num_trans_units],
            tf.float32,
            trainable=False)
    else:
        # initialize from pre-trained entity vectors (frozen during training)
        self.entity_trans = tf.get_variable('entity_embed',
                                            dtype=tf.float32,
                                            initializer=entity_embed,
                                            trainable=False)

    # pass the (frozen) entity embeddings through a trainable dense layer
    self.entity_trans_transformed = tf.layers.dense(
        self.entity_trans,
        num_trans_units,
        activation=tf.tanh,
        name='trans_transformation')
    # zero-initialized embeddings for the 7 special padding tokens
    padding_entity = tf.get_variable('entity_padding_embed',
                                     [7, num_trans_units],
                                     dtype=tf.float32,
                                     initializer=tf.zeros_initializer())
    self.entity_embed = tf.concat(
        [padding_entity, self.entity_trans_transformed], axis=0)

    # embeddings of the knowledge triples
    # [encoder_batch_size, triple_num, triple_len, 3*num_trans_units]
    triples_embedding = tf.reshape(
        tf.nn.embedding_lookup(self.entity_embed,
                               self.entity2index.lookup(self.triples)),
        [encoder_batch_size, triple_num, -1, 3 * num_trans_units])
    # word embeddings of every entity in the graphs
    # [encoder_batch_size, triple_num*triple_len, num_embed_units]
    entities_word_embedding = tf.reshape(
        tf.nn.embedding_lookup(self.embed,
                               self.symbol2index.lookup(self.entities)),
        [encoder_batch_size, -1, num_embed_units])

    head, relation, tail = tf.split(triples_embedding,
                                    [num_trans_units] * 3,
                                    axis=3)

    # static attention of the knowledge-fusion layer
    with tf.variable_scope('graph_attention'):
        # concatenate head and tail
        head_tail = tf.concat([head, tail], axis=3)
        # fuse head+tail into one vector
        head_tail_transformed = tf.layers.dense(head_tail,
                                                num_trans_units,
                                                activation=tf.tanh,
                                                name='head_tail_transform')
        # relation vector
        relation_transformed = tf.layers.dense(relation,
                                               num_trans_units,
                                               name='relation_transform')
        # attention logits from relation and head_tail
        e_weight = tf.reduce_sum(relation_transformed *
                                 head_tail_transformed,
                                 axis=3)
        # normalize the attention weights
        alpha_weight = tf.nn.softmax(e_weight)
        # weighted sum of head_tail by the attention weights
        graph_embed = tf.reduce_sum(tf.expand_dims(alpha_weight, 3) *
                                    head_tail,
                                    axis=2)

    # per-encoder-step static graph vectors, gathered by (batch, graph) index
    # pairs built from posts_triple
    graph_embed_input = tf.gather_nd(
        graph_embed,
        tf.concat([
            tf.tile(
                tf.reshape(tf.range(encoder_batch_size, dtype=tf.int32),
                           [-1, 1, 1]), [1, encoder_len, 1]),
            self.posts_triple
        ],
                  axis=2))
    # per-decoder-step triple embeddings from responses_triple
    triple_embed_input = tf.reshape(
        tf.nn.embedding_lookup(
            self.entity_embed,
            self.entity2index.lookup(self.responses_triple)),
        [batch_size, decoder_len, 3 * num_trans_units])

    post_word_input = tf.nn.embedding_lookup(
        self.embed, self.posts_word_id)  # batch*len*unit
    response_word_input = tf.nn.embedding_lookup(
        self.embed, self.responses_word_id)  # batch*len*unit

    # append the attended graph information to the encoder input
    self.encoder_input = tf.concat([post_word_input, graph_embed_input],
                                   axis=2)
    # append the triple information to the decoder input
    self.decoder_input = tf.concat(
        [response_word_input, triple_embed_input], axis=2)

    # encoder uses GRU cells, num_layers deep
    encoder_cell = MultiRNNCell(
        [GRUCell(num_units) for _ in range(num_layers)])
    # decoder uses GRU cells, num_layers deep
    decoder_cell = MultiRNNCell(
        [GRUCell(num_units) for _ in range(num_layers)])

    # run the RNN encoder
    encoder_output, encoder_state = dynamic_rnn(encoder_cell,
                                                self.encoder_input,
                                                self.posts_length,
                                                dtype=tf.float32,
                                                scope="encoder")

    # get output projection function
    output_fn, selector_fn, sequence_loss, sampled_sequence_loss, total_loss = output_projection_layer(
        num_units, num_symbols, num_samples)

    # decoder used for training
    with tf.variable_scope('decoder'):
        # build the attention functions
        attention_keys_init, attention_values_init, attention_score_fn_init, attention_construct_fn_init \
            = prepare_attention(encoder_output, 'bahdanau', num_units,
                                imem=(graph_embed, triples_embedding),
                                output_alignments=output_alignments and mem_use)
        # 'luong', num_units)
        decoder_fn_train = attention_decoder_fn_train(
            encoder_state,
            attention_keys_init,
            attention_values_init,
            attention_score_fn_init,
            attention_construct_fn_init,
            output_alignments=output_alignments and mem_use,
            max_length=tf.reduce_max(self.responses_length))
        self.decoder_output, _, alignments_ta = dynamic_rnn_decoder(
            decoder_cell,
            decoder_fn_train,
            self.decoder_input,
            self.responses_length,
            scope="decoder_rnn")
        if output_alignments:
            self.alignments = tf.transpose(alignments_ta.stack(),
                                           perm=[1, 0, 2, 3])
            self.decoder_loss, self.ppx_loss, self.sentence_ppx = total_loss(
                self.decoder_output, self.responses_target,
                self.decoder_mask, self.alignments, triples_embedding,
                use_triples, one_hot_triples)
            self.sentence_ppx = tf.identity(self.sentence_ppx,
                                            name='ppx_loss')
        else:
            self.decoder_loss = sequence_loss(self.decoder_output,
                                              self.responses_target,
                                              self.decoder_mask)

    # decoder used for inference (reuses the training decoder's variables)
    with tf.variable_scope('decoder', reuse=True):
        # build the attention functions
        attention_keys, attention_values, attention_score_fn, attention_construct_fn \
            = prepare_attention(encoder_output, 'bahdanau', num_units, reuse=True,
                                imem=(graph_embed, triples_embedding),
                                output_alignments=output_alignments and mem_use)
        # 'luong', num_units)
        decoder_fn_inference = attention_decoder_fn_inference(
            output_fn,
            encoder_state,
            attention_keys,
            attention_values,
            attention_score_fn,
            attention_construct_fn,
            self.embed,
            GO_ID,
            EOS_ID,
            max_length,
            num_symbols,
            imem=(entities_word_embedding,
                  tf.reshape(triples_embedding,
                             [encoder_batch_size, -1, 3 * num_trans_units])),
            selector_fn=selector_fn)
        self.decoder_distribution, _, output_ids_ta = dynamic_rnn_decoder(
            decoder_cell, decoder_fn_inference, scope="decoder_rnn")

        output_len = tf.shape(self.decoder_distribution)[1]
        output_ids = tf.transpose(
            output_ids_ta.gather(tf.range(output_len)))
        # clip ids to the vocabulary range: negative ids mark entity words
        word_ids = tf.cast(tf.clip_by_value(output_ids, 0, num_symbols),
                           tf.int64)
        # flat positions of the entity words inside `entities`:
        # per-example base offset (range(batch) * triple_num*triple_len)
        # plus the per-step entity offset (-output_ids, clipped)
        entity_ids = tf.reshape(
            tf.clip_by_value(-output_ids, 0, num_symbols) + tf.reshape(
                tf.range(encoder_batch_size) *
                tf.shape(entities_word_embedding)[1], [-1, 1]), [-1])
        entities = tf.reshape(
            tf.gather(tf.reshape(self.entities, [-1]), entity_ids),
            [-1, output_len])
        words = self.index2symbol.lookup(word_ids)
        # assemble the output response: vocabulary word where output_ids > 0,
        # otherwise the entity word
        self.generation = tf.where(output_ids > 0, words, entities)
        self.generation = tf.identity(self.generation, name='generation')

    # training setup
    self.learning_rate = tf.Variable(float(learning_rate),
                                     trainable=False,
                                     dtype=tf.float32)
    # learning-rate decay op (not actually invoked by this scheme)
    self.learning_rate_decay_op = self.learning_rate.assign(
        self.learning_rate * learning_rate_decay_factor)
    self.global_step = tf.Variable(0, trainable=False)
    # parameters to differentiate against
    # NOTE(review): tf.global_variables() also includes non-trainable
    # variables (e.g. the frozen entity_embed), whose gradients will be
    # None — confirm tf.trainable_variables() was not intended.
    self.params = tf.global_variables()
    # Adam optimizer: efficient, smooth gradients, simple to tune
    opt = tf.train.AdamOptimizer(learning_rate=learning_rate)
    # NOTE(review): _lr is a private attribute of AdamOptimizer
    self.lr = opt._lr

    gradients = tf.gradients(self.decoder_loss, self.params)
    clipped_gradients, self.gradient_norm = tf.clip_by_global_norm(
        gradients, max_gradient_norm)
    self.update = opt.apply_gradients(zip(clipped_gradients, self.params),
                                      global_step=self.global_step)

    # summaries for loss and variable histograms
    tf.summary.scalar('decoder_loss', self.decoder_loss)
    for each in tf.trainable_variables():
        tf.summary.histogram(each.name, each)
    self.merged_summary_op = tf.summary.merge_all()

    self.saver = tf.train.Saver(write_version=tf.train.SaverDef.V2,
                                max_to_keep=3,
                                pad_step_number=True,
                                keep_checkpoint_every_n_hours=1.0)
    self.saver_epoch = tf.train.Saver(write_version=tf.train.SaverDef.V2,
                                      max_to_keep=1000,
                                      pad_step_number=True)