class CVAE(object):
    """Conditional-VAE response generator built as a TensorFlow 1.x graph.

    The graph encodes three token sequences (post, reference, response),
    derives a recognition distribution from all three and a prior
    distribution from post + reference, samples a latent vector from one of
    them (selected by the ``use_prior`` placeholder), and decodes the
    response with teacher forcing (training), greedy search and beam search.

    Args:
        tfFLAGS: object exposing the hyper-parameters read in ``__init__``
            (``vocab_size``, ``embed_size``, ``num_units``, ``num_layers``,
            ``beam_width``, ``use_lstm``, ``attn_mode``, ``keep_prob``,
            ``max_decode_len``, ``bi_encode``, ``recog_hidden_units``,
            ``prior_hidden_units``, ``z_dim``, ``full_kl_step``, ``opt``,
            ``learning_rate`` and, depending on ``opt``,
            ``learning_rate_decay_factor`` / ``momentum``).
        embed: optional initial value for the shared embedding matrix; when
            ``None`` a ``[vocab_size, embed_size]`` variable is created.

    NOTE(review): this class was re-indented from a flattened source, so the
    nesting of variable scopes (for example ``embedding`` inside ``input``,
    and the optimizer ops inside the first ``decode`` scope) is a
    reconstruction. Confirm variable names against existing checkpoints.
    """

    def __init__(self, tfFLAGS, embed=None):
        self.vocab_size = tfFLAGS.vocab_size
        self.embed_size = tfFLAGS.embed_size
        self.num_units = tfFLAGS.num_units
        self.num_layers = tfFLAGS.num_layers
        self.beam_width = tfFLAGS.beam_width
        self.use_lstm = tfFLAGS.use_lstm
        self.attn_mode = tfFLAGS.attn_mode
        self.train_keep_prob = tfFLAGS.keep_prob
        self.max_decode_len = tfFLAGS.max_decode_len
        self.bi_encode = tfFLAGS.bi_encode
        self.recog_hidden_units = tfFLAGS.recog_hidden_units
        self.prior_hidden_units = tfFLAGS.prior_hidden_units
        self.z_dim = tfFLAGS.z_dim
        self.full_kl_step = tfFLAGS.full_kl_step

        self.global_step = tf.Variable(0, name="global_step", trainable=False)
        self.max_gradient_norm = 5.0

        if tfFLAGS.opt == 'SGD':
            # Only the SGD branch keeps the rate in a variable so that it can
            # be decayed through ``learning_rate_decay_op``.
            self.learning_rate = tf.Variable(float(tfFLAGS.learning_rate),
                                             trainable=False,
                                             dtype=tf.float32)
            self.learning_rate_decay_op = self.learning_rate.assign(
                self.learning_rate * tfFLAGS.learning_rate_decay_factor)
            self.opt = tf.train.GradientDescentOptimizer(self.learning_rate)
        elif tfFLAGS.opt == 'Momentum':
            # Expose the rate like the other branches do, so that reading
            # ``model.learning_rate`` does not raise AttributeError.
            self.learning_rate = tfFLAGS.learning_rate
            self.opt = tf.train.MomentumOptimizer(
                learning_rate=tfFLAGS.learning_rate,
                momentum=tfFLAGS.momentum)
        else:
            self.learning_rate = tfFLAGS.learning_rate
            # NOTE(review): Adam is built with its default learning rate;
            # ``tfFLAGS.learning_rate`` is stored but not passed to it.
            # Left unchanged to avoid silently altering training.
            self.opt = tf.train.AdamOptimizer()

        self._make_input(embed)

        with tf.variable_scope("output_layer"):
            self.output_layer = Dense(
                self.vocab_size,
                kernel_initializer=tf.truncated_normal_initializer(
                    stddev=0.1))

        with tf.variable_scope("encoders",
                               initializer=tf.orthogonal_initializer()):
            self.enc_post_outputs, self.enc_post_state = self._build_encoder(
                scope='post_encoder',
                inputs=self.enc_post,
                sequence_length=self.post_len)
            self.enc_ref_outputs, self.enc_ref_state = self._build_encoder(
                scope='ref_encoder',
                inputs=self.enc_ref,
                sequence_length=self.ref_len)
            self.enc_response_outputs, self.enc_response_state = \
                self._build_encoder(scope='resp_encoder',
                                    inputs=self.enc_response,
                                    sequence_length=self.response_len)

        self.post_state = self._get_representation_from_enc_state(
            self.enc_post_state)
        self.ref_state = self._get_representation_from_enc_state(
            self.enc_ref_state)
        self.response_state = self._get_representation_from_enc_state(
            self.enc_response_state)
        # Condition vector: post and reference representations concatenated.
        self.cond_embed = tf.concat([self.post_state, self.ref_state],
                                    axis=-1)

        with tf.variable_scope("RecognitionNetwork"):
            recog_input = tf.concat([self.cond_embed, self.response_state],
                                    axis=-1)
            recog_hidden = tf.layers.dense(inputs=recog_input,
                                           units=self.recog_hidden_units,
                                           activation=tf.nn.tanh)
            recog_mulogvar = tf.layers.dense(inputs=recog_hidden,
                                             units=self.z_dim * 2,
                                             activation=None)
            recog_mu, recog_logvar = tf.split(recog_mulogvar, 2, axis=-1)

        with tf.variable_scope("PriorNetwork"):
            prior_input = self.cond_embed
            prior_hidden = tf.layers.dense(inputs=prior_input,
                                           units=self.prior_hidden_units,
                                           activation=tf.nn.tanh)
            prior_mulogvar = tf.layers.dense(inputs=prior_hidden,
                                             units=self.z_dim * 2,
                                             activation=None)
            prior_mu, prior_logvar = tf.split(prior_mulogvar, 2, axis=-1)

        with tf.variable_scope("GenerationNetwork"):
            # ``use_prior`` picks the distribution the latent is drawn from.
            latent_sample = tf.cond(
                self.use_prior,
                lambda: sample_gaussian(prior_mu, prior_logvar),
                lambda: sample_gaussian(recog_mu, recog_logvar),
                name='latent_sample')
            gen_input = tf.concat([self.cond_embed, latent_sample], axis=-1)
            # One independent projection per decoder layer (and per c/h for
            # LSTM) maps [condition; z] to the decoder's initial state.
            if self.use_lstm:
                self.dec_init_state = tuple([
                    tf.contrib.rnn.LSTMStateTuple(
                        c=tf.layers.dense(inputs=gen_input,
                                          units=self.num_units,
                                          activation=None),
                        h=tf.layers.dense(inputs=gen_input,
                                          units=self.num_units,
                                          activation=None))
                    for _ in range(self.num_layers)
                ])
                print(self.dec_init_state)
            else:
                self.dec_init_state = tuple([
                    tf.layers.dense(inputs=gen_input,
                                    units=self.num_units,
                                    activation=None)
                    for _ in range(self.num_layers)
                ])

        kld = gaussian_kld(recog_mu, recog_logvar, prior_mu, prior_logvar)
        self.avg_kld = tf.reduce_mean(kld)
        # KL weight grows linearly with global_step and saturates at 1.0
        # after ``full_kl_step`` steps.
        self.kl_weights = tf.minimum(
            tf.to_float(self.global_step) / self.full_kl_step, 1.0)
        self.kl_loss = self.kl_weights * self.avg_kld

        self._build_decoder()

        self.saver = tf.train.Saver(write_version=tf.train.SaverDef.V2,
                                    max_to_keep=1,
                                    pad_step_number=True,
                                    keep_checkpoint_every_n_hours=1.0)
        for var in tf.trainable_variables():
            print(var)

    def _make_input(self, embed):
        """Create lookup tables, placeholders and embedded model inputs.

        Args:
            embed: optional initial value for the shared embedding matrix.
        """
        self.symbol2index = MutableHashTable(key_dtype=tf.string,
                                             value_dtype=tf.int64,
                                             default_value=UNK_ID,
                                             shared_name="in_table",
                                             name="in_table",
                                             checkpoint=True)
        self.index2symbol = MutableHashTable(key_dtype=tf.int64,
                                             value_dtype=tf.string,
                                             default_value=_UNK,
                                             shared_name="out_table",
                                             name="out_table",
                                             checkpoint=True)
        with tf.variable_scope("input"):
            self.post_string = tf.placeholder(tf.string, (None, None),
                                              'post_string')
            self.ref_string = tf.placeholder(tf.string, (None, None),
                                             'ref_string')
            self.response_string = tf.placeholder(tf.string, (None, None),
                                                  'response_string')

            self.post = self.symbol2index.lookup(self.post_string)
            self.post_len = tf.placeholder(tf.int32, (None, ), 'post_len')
            self.ref = self.symbol2index.lookup(self.ref_string)
            self.ref_len = tf.placeholder(tf.int32, (None, ), 'ref_len')
            self.response = self.symbol2index.lookup(self.response_string)
            self.response_len = tf.placeholder(tf.int32, (None, ),
                                               'response_len')

            with tf.variable_scope("embedding") as scope:
                if embed is None:
                    # initialize the embedding randomly
                    self.emb_enc = self.emb_dec = tf.get_variable(
                        "emb_share", [self.vocab_size, self.embed_size],
                        dtype=tf.float32)
                else:
                    # initialize the embedding by pre-trained word vectors
                    print("share pre-trained embed")
                    self.emb_enc = self.emb_dec = tf.get_variable(
                        'emb_share', dtype=tf.float32, initializer=embed)

            self.enc_post = tf.nn.embedding_lookup(self.emb_enc, self.post)
            self.enc_ref = tf.nn.embedding_lookup(self.emb_enc, self.ref)
            self.enc_response = tf.nn.embedding_lookup(self.emb_enc,
                                                       self.response)

            self.batch_len = tf.shape(self.response)[1]
            self.batch_size = tf.shape(self.response)[0]
            # Decoder input = GO column followed by the response without its
            # last time step.
            self.response_input = tf.concat([
                tf.ones((self.batch_size, 1), dtype=tf.int64) * GO_ID,
                tf.split(self.response, [self.batch_len - 1, 1], axis=1)[0]
            ], 1)
            self.dec_inp = tf.nn.embedding_lookup(self.emb_dec,
                                                  self.response_input)

            self.keep_prob = tf.placeholder_with_default(1.0, ())
            self.use_prior = tf.placeholder(dtype=tf.bool, name="use_prior")

    def _build_encoder(self, scope, inputs, sequence_length):
        """Run a uni- or bi-directional multi-layer RNN over ``inputs``.

        Args:
            scope: variable-scope name for this encoder.
            inputs: embedded input sequence passed to the dynamic RNN.
            sequence_length: per-example lengths passed to the dynamic RNN.

        Returns:
            ``(enc_outputs, enc_state)``. In bidirectional mode the forward
            and backward outputs/states are concatenated on the last axis
            and the state is a tuple with one entry per layer.
        """
        with tf.variable_scope(scope):
            if not self.bi_encode:
                enc_cell = self._build_encoder_cell()
                enc_outputs, enc_state = tf.nn.dynamic_rnn(
                    cell=enc_cell,
                    inputs=inputs,
                    sequence_length=sequence_length,
                    dtype=tf.float32)
                return enc_outputs, enc_state

            cell_fw, cell_bw = self._build_biencoder_cell()
            outputs, states = tf.nn.bidirectional_dynamic_rnn(
                cell_fw=cell_fw,
                cell_bw=cell_bw,
                inputs=inputs,
                sequence_length=sequence_length,
                dtype=tf.float32)
            enc_outputs = tf.concat(outputs, axis=-1)
            enc_state = []
            for i in range(self.num_layers):
                if self.use_lstm:
                    encoder_state_c = tf.concat(
                        [states[0][i].c, states[1][i].c], axis=-1)
                    encoder_state_h = tf.concat(
                        [states[0][i].h, states[1][i].h], axis=-1)
                    enc_state.append(
                        tf.contrib.rnn.LSTMStateTuple(c=encoder_state_c,
                                                      h=encoder_state_h))
                else:
                    enc_state.append(
                        tf.concat([states[0][i], states[1][i]], axis=-1))
            return enc_outputs, tuple(enc_state)

    def _get_representation_from_enc_state(self, enc_state):
        """Concatenate the per-layer states (``h`` only for LSTM) on the last axis."""
        if self.use_lstm:
            return tf.concat([state.h for state in enc_state], axis=-1)
        return tf.concat(enc_state, axis=-1)

    def _build_decoder(self):
        """Build the training, greedy and beam-search decoders.

        Sets ``sen_loss``, ``ppl_loss``, ``elbo``, ``train_op``,
        ``train_out``, ``inference`` and ``beam_out``.
        """
        with tf.variable_scope("decode",
                               initializer=tf.orthogonal_initializer()):
            dec_cell, init_state = self._build_decoder_cell(
                self.enc_post_outputs, self.post_len, self.dec_init_state)

            train_helper = tf.contrib.seq2seq.TrainingHelper(
                inputs=self.dec_inp, sequence_length=self.response_len)
            train_decoder = tf.contrib.seq2seq.BasicDecoder(
                cell=dec_cell,
                helper=train_helper,
                initial_state=init_state,
                output_layer=self.output_layer)
            train_output, _, _ = tf.contrib.seq2seq.dynamic_decode(
                decoder=train_decoder,
                maximum_iterations=self.max_decode_len,
            )
            logits = train_output.rnn_output

            mask = tf.sequence_mask(self.response_len,
                                    self.batch_len,
                                    dtype=tf.float32)
            crossent = tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=self.response, logits=logits)
            crossent = tf.reduce_sum(crossent * mask)
            # Summed token loss averaged over the batch (per sentence).
            self.sen_loss = crossent / tf.to_float(self.batch_size)
            # Summed token loss averaged over unmasked tokens; the original
            # author notes this equals
            # tf.contrib.seq2seq.sequence_loss(rnn_output, response, mask).
            self.ppl_loss = crossent / tf.reduce_sum(mask)

            # Training objective: reconstruction loss plus weighted KL.
            self.elbo = self.sen_loss + self.kl_loss

            # Calculate and clip gradients
            params = tf.trainable_variables()
            gradients = tf.gradients(self.elbo, params)
            clipped_gradients, _ = tf.clip_by_global_norm(
                gradients, self.max_gradient_norm)
            self.train_op = self.opt.apply_gradients(
                zip(clipped_gradients, params),
                global_step=self.global_step)

            self.train_out = self.index2symbol.lookup(
                tf.cast(train_output.sample_id, tf.int64), name='train_out')

        with tf.variable_scope("decode", reuse=True):
            dec_cell, init_state = self._build_decoder_cell(
                self.enc_post_outputs, self.post_len, self.dec_init_state)

            start_tokens = tf.tile(tf.constant([GO_ID], dtype=tf.int32),
                                   [self.batch_size])
            end_token = EOS_ID
            infer_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
                self.emb_dec, start_tokens, end_token)
            infer_decoder = tf.contrib.seq2seq.BasicDecoder(
                cell=dec_cell,
                helper=infer_helper,
                initial_state=init_state,
                output_layer=self.output_layer)
            infer_output, _, _ = tf.contrib.seq2seq.dynamic_decode(
                decoder=infer_decoder,
                maximum_iterations=self.max_decode_len,
            )
            self.inference = self.index2symbol.lookup(
                tf.cast(infer_output.sample_id, tf.int64), name='inference')

        with tf.variable_scope("decode", reuse=True):
            # Beam search needs every batch-major tensor repeated
            # ``beam_width`` times.
            dec_init_state = tf.contrib.seq2seq.tile_batch(
                self.dec_init_state, self.beam_width)
            enc_outputs = tf.contrib.seq2seq.tile_batch(
                self.enc_post_outputs, self.beam_width)
            post_len = tf.contrib.seq2seq.tile_batch(self.post_len,
                                                     self.beam_width)

            dec_cell, init_state = self._build_decoder_cell(
                enc_outputs,
                post_len,
                dec_init_state,
                beam_width=self.beam_width)

            beam_decoder = tf.contrib.seq2seq.BeamSearchDecoder(
                cell=dec_cell,
                embedding=self.emb_dec,
                start_tokens=tf.ones_like(self.post_len) * GO_ID,
                end_token=EOS_ID,
                initial_state=init_state,
                beam_width=self.beam_width,
                output_layer=self.output_layer)
            beam_output, _, beam_lengths = tf.contrib.seq2seq.dynamic_decode(
                decoder=beam_decoder,
                maximum_iterations=self.max_decode_len,
            )
            self.beam_out = self.index2symbol.lookup(
                tf.cast(beam_output.predicted_ids, tf.int64),
                name='beam_out')

    def _stacked_rnn_cell(self, num_units):
        """Return ``num_layers`` dropout-wrapped LSTM/GRU cells as a MultiRNNCell."""
        if self.use_lstm:
            cell_cls = tf.contrib.rnn.LSTMCell
        else:
            cell_cls = tf.contrib.rnn.GRUCell
        return tf.contrib.rnn.MultiRNNCell([
            tf.contrib.rnn.DropoutWrapper(cell_cls(num_units),
                                          self.keep_prob)
            for _ in range(self.num_layers)
        ])

    def _build_encoder_cell(self):
        """Return the stacked cell used by the unidirectional encoder."""
        return self._stacked_rnn_cell(self.num_units)

    def _build_biencoder_cell(self):
        """Return ``(cell_fw, cell_bw)``, each with half of ``num_units``.

        Floor division keeps the unit count an integer on Python 3 as well;
        for integer ``num_units`` it matches Python 2's ``/``.
        """
        half_units = self.num_units // 2
        cell_fw = self._stacked_rnn_cell(half_units)
        cell_bw = self._stacked_rnn_cell(half_units)
        return cell_fw, cell_bw

    def _build_decoder_cell(self,
                            memory,
                            memory_len,
                            encode_state,
                            beam_width=1):
        """Build the decoder cell and its initial state.

        Args:
            memory: tensor attended over (passed as attention ``memory``).
            memory_len: lengths passed as ``memory_sequence_length``.
            encode_state: state used to initialise the decoder.
            beam_width: multiplier applied to ``batch_size`` when creating
                the attention wrapper's zero state.

        Returns:
            ``(cell, initial_state)``. Without a recognised ``attn_mode``
            the plain stacked cell and ``encode_state`` are returned.
        """
        cell = self._stacked_rnn_cell(self.num_units)

        if self.attn_mode == 'Luong':
            attention_mechanism = tf.contrib.seq2seq.LuongAttention(
                num_units=self.num_units,
                memory=memory,
                memory_sequence_length=memory_len,
                scale=True)
        elif self.attn_mode == 'Bahdanau':
            # BahdanauAttention has no ``scale`` argument; its counterpart
            # is ``normalize``. Passing ``scale`` raised a TypeError.
            attention_mechanism = tf.contrib.seq2seq.BahdanauAttention(
                num_units=self.num_units,
                memory=memory,
                memory_sequence_length=memory_len,
                normalize=True)
        else:
            return cell, encode_state

        attn_cell = tf.contrib.seq2seq.AttentionWrapper(
            cell=cell,
            attention_mechanism=attention_mechanism,
            attention_layer_size=self.num_units,
        )
        init_state = attn_cell.zero_state(
            self.batch_size * beam_width,
            tf.float32).clone(cell_state=encode_state)
        return attn_cell, init_state

    def initialize(self, sess, vocab):
        """Initialise all global variables and fill both vocabulary tables.

        Args:
            sess: session used to run the initialisation ops.
            vocab: list of symbols; position ``i`` is mapped to index ``i``.
        """
        # Materialise the range so the constant is built from a plain list
        # on Python 3 too.
        indices = list(range(len(vocab)))
        op_in = self.symbol2index.insert(
            constant_op.constant(vocab),
            constant_op.constant(indices, dtype=tf.int64))
        op_out = self.index2symbol.insert(
            constant_op.constant(indices, dtype=tf.int64),
            constant_op.constant(vocab))
        sess.run(tf.global_variables_initializer())
        sess.run([op_in, op_out])

    def step(self, sess, data, is_train=False, use_prior=False):
        """Run one training or evaluation step.

        Args:
            sess: session exposing ``run(fetches, feed_dict)``.
            data: dict with keys ``post``, ``post_len``, ``ref``,
                ``ref_len``, ``response`` and ``response_len``.
            is_train: when true, ``train_op`` is run and dropout uses
                ``train_keep_prob``.
            use_prior: value fed to the ``use_prior`` placeholder. The
                default (``False``) samples the latent from the recognition
                network. The previous code fed ``is_train`` here, so every
                training step sampled from the prior and the reconstruction
                loss never reached the recognition network.

        Returns:
            The result of ``sess.run`` for ``[ppl_loss, elbo, sen_loss,
            kl_loss, avg_kld, kl_weights]``, preceded by ``train_op`` when
            ``is_train`` is true.
        """
        input_feed = {
            self.post_string: data['post'],
            self.post_len: data['post_len'],
            self.ref_string: data['ref'],
            self.ref_len: data['ref_len'],
            self.response_string: data['response'],
            self.response_len: data['response_len'],
            self.use_prior: bool(use_prior),
        }
        losses = [
            self.ppl_loss,
            self.elbo,
            self.sen_loss,
            self.kl_loss,
            self.avg_kld,
            self.kl_weights,
        ]
        if is_train:
            output_feed = [self.train_op] + losses
            input_feed[self.keep_prob] = self.train_keep_prob
        else:
            output_feed = losses
        return sess.run(output_feed, input_feed)
class Model(object):
    """Knowledge-graph-aware encoder/decoder built as a TensorFlow 1.x graph.

    Construction wires up, in order: embedding tables, placeholders, string
    lookup tables, entity/triple embeddings, a GRU encoder, an attention
    decoder (training and inference paths) and the Adam update op.

    NOTE(review): this class was re-indented from a flattened source; the
    extent of each ``tf.variable_scope`` block is a reconstruction.
    """

    def __init__(self,
                 word_embed,
                 entity_embed,
                 vocab_size=30000,
                 num_embed_units=300,
                 num_units=512,
                 num_layers=2,
                 num_entities=0,
                 num_trans_units=100,
                 max_length=60,
                 learning_rate=0.0001,
                 learning_rate_decay_factor=0.95,
                 max_gradient_norm=5.0,
                 num_samples=500,
                 output_alignments=True):
        """Store hyper-parameters and build the whole graph.

        Args:
            word_embed: initial word-embedding matrix, or ``None`` to create
                a ``[vocab_size, num_embed_units]`` variable.
            entity_embed: initial entity-embedding matrix, or ``None`` to
                create a ``[num_entities, num_trans_units]`` variable.
            learning_rate_decay_factor: accepted but not read by this class.
            Remaining arguments are stored as same-named attributes.
        """
        # hyper-parameters
        self.vocab_size = vocab_size
        self.num_embed_units = num_embed_units
        self.num_units = num_units
        self.num_layers = num_layers
        self.num_entities = num_entities
        self.num_trans_units = num_trans_units
        self.learning_rate = learning_rate
        self.max_gradient_norm = max_gradient_norm
        self.num_samples = num_samples
        self.max_length = max_length
        self.output_alignments = output_alignments

        # word embedding table: random unless pre-trained vectors are given
        if word_embed is not None:
            self.word_embed = tf.get_variable('word_embed',
                                              dtype=tf.float32,
                                              initializer=word_embed)
        else:
            self.word_embed = tf.get_variable(
                'word_embed', [self.vocab_size, self.num_embed_units],
                tf.float32)

        # entity embedding table, declared non-trainable in both cases
        if entity_embed is not None:
            self.entity_trans = tf.get_variable('entity_embed',
                                                dtype=tf.float32,
                                                initializer=entity_embed,
                                                trainable=False)
        else:
            self.entity_trans = tf.get_variable(
                'entity_embed', [num_entities, num_trans_units],
                tf.float32,
                trainable=False)

        self._create_placeholders()
        self._init_vocabs()

        # string -> id lookups
        self.posts_word_id = self.symbol2index.lookup(self.posts)
        self.posts_entity_id = self.entity2index.lookup(self.posts)
        self.responses_target = self.symbol2index.lookup(self.responses)

        num_rows = tf.shape(self.responses)[0]
        num_steps = tf.shape(self.responses)[1]

        # decoder input ids: a GO column followed by the targets minus
        # their last step
        go_column = tf.ones([num_rows, 1], dtype=tf.int64) * GO_ID
        targets_head = tf.split(self.responses_target, [num_steps - 1, 1],
                                1)[0]
        self.responses_word_id = tf.concat([go_column, targets_head], 1)

        # mask: reverse cumulative sum of a one-hot at position length-1,
        # i.e. ones up to and including that position
        last_step_one_hot = tf.one_hot(self.responses_length - 1, num_steps)
        ones_through_last = tf.cumsum(last_step_one_hot,
                                      reverse=True,
                                      axis=1)
        self.decoder_mask = tf.reshape(ones_through_last, [-1, num_steps])

        # entity embeddings: 7 zero-initialised rows are prepended to the
        # tanh-projected entity table
        projected_entities = tf.layers.dense(self.entity_trans,
                                             self.num_trans_units,
                                             activation=tf.tanh,
                                             name='trans_transformation')
        padding_rows = tf.get_variable('entity_padding_embed',
                                       [7, self.num_trans_units],
                                       dtype=tf.float32,
                                       initializer=tf.zeros_initializer())
        self.entity_embed = tf.concat([padding_rows, projected_entities],
                                      axis=0)

        # knowledge-graph embeddings
        (self.triples_embedding, self.entities_word_embedding,
         self.graph_embedding) = self._build_kg_embedding()

        graph_inputs, triple_inputs = self._build_kg_graph()
        enc_output, enc_state = self._build_encoder(graph_inputs)
        self._build_decoder(enc_output, enc_state, triple_inputs)

        self._create_train_ops()

    def _create_placeholders(self):
        """Declare every feedable input tensor of the model."""
        self.posts = tf.placeholder(tf.string, (None, None), 'enc_inps')
        self.posts_length = tf.placeholder(tf.int32, (None), 'enc_lens')
        self.responses = tf.placeholder(tf.string, (None, None),
                                        'dec_inps')
        self.responses_length = tf.placeholder(tf.int32, (None),
                                               'dec_lens')
        self.entities = tf.placeholder(tf.string, (None, None, None),
                                       'entities')
        self.entity_masks = tf.placeholder(tf.string, (None, None),
                                           'entity_masks')
        self.triples = tf.placeholder(tf.string, (None, None, None, 3),
                                      'triples')
        self.posts_triple = tf.placeholder(tf.int32, (None, None, 1),
                                           'enc_triples')
        self.responses_triple = tf.placeholder(tf.string, (None, None, 3),
                                               'dec_triples')
        self.match_triples = tf.placeholder(tf.int32, (None, None, None),
                                            'match_triples')

    def _create_train_ops(self):
        """Create the global step, clipped gradients, Adam update and summaries."""
        self.global_step = tf.Variable(0, trainable=False)
        # NOTE(review): gradients are taken w.r.t. *all* global variables,
        # which includes the ``entity_embed`` variable declared with
        # trainable=False -- confirm this is intended.
        self.params = tf.global_variables()
        raw_gradients = tf.gradients(self.decoder_loss, self.params)
        self.clipped_gradients, self.gradient_norm = \
            tf.clip_by_global_norm(raw_gradients, self.max_gradient_norm)
        adam = tf.train.AdamOptimizer(learning_rate=self.learning_rate)
        self.update = adam.apply_gradients(zip(self.clipped_gradients,
                                               self.params),
                                           global_step=self.global_step)

        tf.summary.scalar('decoder_loss', self.decoder_loss)
        for variable in tf.trainable_variables():
            tf.summary.histogram(variable.name, variable)
        self.merged_summary_op = tf.summary.merge_all()

    def _init_vocabs(self):
        """Create the four checkpointed hash tables (word/entity, both directions)."""

        def new_table(key_dtype, value_dtype, default_value, table_name):
            return MutableHashTable(key_dtype=key_dtype,
                                    value_dtype=value_dtype,
                                    default_value=default_value,
                                    shared_name=table_name,
                                    name=table_name,
                                    checkpoint=True)

        self.symbol2index = new_table(tf.string, tf.int64, UNK_ID,
                                      "in_table")
        self.index2symbol = new_table(tf.int64, tf.string, '_UNK',
                                      "out_table")
        self.entity2index = new_table(tf.string, tf.int64, NONE_ID,
                                      "entity_in_table")
        self.index2entity = new_table(tf.int64, tf.string, '_NONE',
                                      "entity_out_table")

    def _build_kg_embedding(self):
        """Embed the triples and entities and pool each graph by attention.

        Returns:
            ``(triples_embedding, entities_word_embedding, graph_embedding)``.
        """
        post_rows, post_steps = tf.unstack(tf.shape(self.posts))
        graphs_per_post = tf.shape(self.triples)[1]

        triple_ids = self.entity2index.lookup(self.triples)
        triple_vectors = tf.nn.embedding_lookup(self.entity_embed,
                                                triple_ids)
        triples_embedding = tf.reshape(triple_vectors, [
            post_rows, graphs_per_post, -1, 3 * self.num_trans_units
        ])

        entity_word_ids = self.symbol2index.lookup(self.entities)
        entity_word_vectors = tf.nn.embedding_lookup(self.word_embed,
                                                     entity_word_ids)
        entities_word_embedding = tf.reshape(
            entity_word_vectors, [post_rows, -1, self.num_embed_units])

        head, relation, tail = tf.split(triples_embedding,
                                        [self.num_trans_units] * 3,
                                        axis=3)

        with tf.variable_scope('graph_attention', reuse=tf.AUTO_REUSE):
            head_tail = tf.concat([head, tail], axis=3)
            head_tail_proj = tf.layers.dense(head_tail,
                                             self.num_trans_units,
                                             activation=tf.tanh,
                                             name='head_tail_transform')
            relation_proj = tf.layers.dense(relation,
                                            self.num_trans_units,
                                            name='relation_transform')
            # score each triple, softmax over the last axis, then take the
            # weighted sum of [head; tail] over axis 2
            scores = tf.reduce_sum(relation_proj * head_tail_proj, axis=3)
            weights = tf.nn.softmax(scores)
            weighted = tf.expand_dims(weights, 3) * head_tail
            graph_embedding = tf.reduce_sum(weighted, axis=2)

        return triples_embedding, entities_word_embedding, graph_embedding

    def _build_kg_graph(self):
        """Gather per-token graph vectors and embed the response triples.

        Returns:
            ``(graph_embed_input, triple_embed_input)``.
        """
        post_rows, post_steps = tf.unstack(tf.shape(self.posts))
        resp_rows = tf.shape(self.responses)[0]
        resp_steps = tf.shape(self.responses)[1]

        # Pair every (row, step) with its row index so gather_nd can select
        # graph_embedding[row, posts_triple[row, step]].
        row_index = tf.range(post_rows, dtype=tf.int32)
        row_index = tf.reshape(row_index, [-1, 1, 1])
        row_index = tf.tile(row_index, [1, post_steps, 1])
        gather_index = tf.concat([row_index, self.posts_triple], axis=2)
        graph_embed_input = tf.gather_nd(self.graph_embedding, gather_index)

        response_triple_ids = self.entity2index.lookup(
            self.responses_triple)
        response_triple_vectors = tf.nn.embedding_lookup(
            self.entity_embed, response_triple_ids)
        triple_embed_input = tf.reshape(
            response_triple_vectors,
            [resp_rows, resp_steps, 3 * self.num_trans_units])

        return graph_embed_input, triple_embed_input

    def _build_encoder(self, graph_embed_input):
        """Encode [word embedding; graph vector] with a multi-layer GRU.

        Returns:
            ``(encoder_output, encoder_state)`` from ``dynamic_rnn``.
        """
        word_vectors = tf.nn.embedding_lookup(self.word_embed,
                                              self.posts_word_id)
        gru_stack = MultiRNNCell(
            [GRUCell(self.num_units) for _ in range(self.num_layers)])
        rnn_input = tf.concat([word_vectors, graph_embed_input], axis=2)
        return dynamic_rnn(gru_stack,
                           rnn_input,
                           self.posts_length,
                           dtype=tf.float32,
                           scope="encoder")

    def _build_decoder(self, encoder_output, encoder_state,
                       triple_embed_input):
        """Build the training and inference decoders.

        Sets ``decoder_loss``, ``ppx_loss``, ``sentence_ppx`` and
        ``generation``.
        """
        post_rows, post_steps = tf.unstack(tf.shape(self.posts))

        # decoder input: [word embedding; triple embedding]
        target_word_vectors = tf.nn.embedding_lookup(
            self.word_embed, self.responses_word_id)
        decoder_input = tf.concat(
            [target_word_vectors, triple_embed_input], axis=2)
        print("decoder_input:", decoder_input.shape)

        decoder_cell = MultiRNNCell(
            [GRUCell(self.num_units) for _ in range(self.num_layers)])

        _, total_loss = loss_computation(self.vocab_size,
                                         num_samples=self.num_samples)

        # ---- training path ----
        with tf.variable_scope('decoder'):
            keys, values, score_fn, construct_fn = prepare_attention(
                encoder_output,
                'bahdanau',
                self.num_units,
                scope_name="decoder",
                imem=(self.graph_embedding, self.triples_embedding),
                output_alignments=self.output_alignments)
            print("graph_embedding:", self.graph_embedding.shape)
            print("triples_embedding:", self.triples_embedding.shape)

            train_fn = attention_decoder_fn_train(
                encoder_state,
                keys,
                values,
                score_fn,
                construct_fn,
                output_alignments=self.output_alignments,
                max_length=tf.reduce_max(self.responses_length))

            decoder_output, _, train_context_state = dynamic_rnn_decoder(
                decoder_cell,
                train_fn,
                decoder_input,
                self.responses_length,
                scope="decoder_rnn")

            output_fn, selector_fn = output_projection(
                self.vocab_size, scope_name="decoder_rnn")
            output_logits = output_fn(decoder_output)
            selector_logits = selector_fn(decoder_output)
            print("decoder_output:", decoder_output.shape)
            print("output_logits:", output_logits.shape)
            print("selector_fn:", selector_logits.name)

            triples_per_graph = tf.shape(self.triples)[2]
            one_hot_triples = tf.one_hot(self.match_triples,
                                         triples_per_graph)
            use_triples = tf.reduce_sum(one_hot_triples, axis=[2, 3])
            alignments = tf.transpose(train_context_state.stack(),
                                      perm=[1, 0, 2, 3])

            losses = total_loss(output_logits, selector_logits,
                                self.responses_target, self.decoder_mask,
                                alignments, use_triples, one_hot_triples)
            self.decoder_loss, self.ppx_loss, self.sentence_ppx = losses
            self.sentence_ppx = tf.identity(self.sentence_ppx,
                                            name="ppx_loss")

        # ---- inference path (shares the training variables) ----
        with tf.variable_scope('decoder', reuse=True):
            keys, values, score_fn, construct_fn = prepare_attention(
                encoder_output,
                'bahdanau',
                self.num_units,
                scope_name="decoder",
                imem=(self.graph_embedding, self.triples_embedding),
                output_alignments=self.output_alignments,
                reuse=True)
            output_fn, selector_fn = output_projection(self.vocab_size,
                                                       scope_name=None,
                                                       reuse=True)

            flat_triples = tf.reshape(
                self.triples_embedding,
                [post_rows, -1, 3 * self.num_trans_units])
            infer_fn = attention_decoder_fn_inference(
                output_fn,
                encoder_state,
                keys,
                values,
                score_fn,
                construct_fn,
                self.word_embed,
                GO_ID,
                EOS_ID,
                self.max_length,
                self.vocab_size,
                imem=(self.entities_word_embedding, flat_triples),
                selector_fn=selector_fn)

            decoder_distribution, _, infer_context_state = \
                dynamic_rnn_decoder(decoder_cell,
                                    infer_fn,
                                    scope="decoder_rnn")

            output_len = tf.shape(decoder_distribution)[1]
            output_ids = tf.transpose(
                infer_context_state.gather(tf.range(output_len)))

            # ids > 0 are looked up as words; for the rest the negated id
            # (clipped to [0, vocab_size]) indexes the flattened entities
            # of the same row
            word_ids = tf.cast(
                tf.clip_by_value(output_ids, 0, self.vocab_size), tf.int64)
            negated_ids = tf.clip_by_value(-output_ids, 0, self.vocab_size)
            row_offsets = tf.reshape(
                tf.range(post_rows) *
                tf.shape(self.entities_word_embedding)[1], [-1, 1])
            entity_ids = tf.reshape(negated_ids + row_offsets, [-1])
            flat_entities = tf.reshape(self.entities, [-1])
            entities = tf.reshape(tf.gather(flat_entities, entity_ids),
                                  [-1, output_len])
            words = self.index2symbol.lookup(word_ids)

            chosen = tf.where(output_ids > 0, words, entities)
            self.generation = tf.identity(chosen, name='generation')

    def set_vocabs(self, session, vocab, entity_vocab, relation_vocab):
        """Fill the word and entity/relation lookup tables.

        Words are numbered ``0..vocab_size-1``; entities followed by
        relations are numbered from 0 in one shared id space.

        Returns:
            The ``session`` that was passed in.
        """
        word_ids = list(range(self.vocab_size))
        session.run(
            self.symbol2index.insert(
                constant_op.constant(vocab),
                constant_op.constant(word_ids, dtype=tf.int64)))
        session.run(
            self.index2symbol.insert(
                constant_op.constant(word_ids, dtype=tf.int64),
                constant_op.constant(vocab)))

        kg_symbols = entity_vocab + relation_vocab
        kg_ids = list(range(len(entity_vocab) + len(relation_vocab)))
        session.run(
            self.entity2index.insert(
                constant_op.constant(kg_symbols),
                constant_op.constant(kg_ids, dtype=tf.int64)))
        session.run(
            self.index2entity.insert(
                constant_op.constant(kg_ids, dtype=tf.int64),
                constant_op.constant(kg_symbols)))
        return session

    def print_parameters(self):
        """Print the name and static shape of every entry in ``self.params``."""
        for variable in self.params:
            shape = variable.get_shape().as_list()
            print('%s: %s' % (variable.name, shape))

    def step_train(self, session, data, forward_only=False, summary=False):
        """Run one step and return the fetched values.

        Args:
            session: object exposing ``run(fetches, feed_dict)``.
            data: dict with keys ``posts``, ``posts_length``, ``responses``,
                ``responses_length``, ``triples``, ``posts_triple``,
                ``responses_triple`` and ``match_triples``.
            forward_only: when true only ``sentence_ppx`` is fetched;
                otherwise ``decoder_loss`` and the update op are fetched too.
            summary: when true the merged summary op is appended.

        Returns:
            Whatever ``session.run`` returns for the selected fetches.
        """
        fed_keys = ('posts', 'posts_length', 'responses',
                    'responses_length', 'triples', 'posts_triple',
                    'responses_triple', 'match_triples')
        # Each placeholder attribute has the same name as its data key.
        input_feed = {getattr(self, key): data[key] for key in fed_keys}

        output_feed = [self.sentence_ppx]
        if not forward_only:
            output_feed += [self.decoder_loss, self.update]
        if summary:
            output_feed.append(self.merged_summary_op)
        return session.run(output_feed, input_feed)