def __call__(self, name, max_len, reuse=False):
    """Build a char-level RNN component over paired source/target words.

    Source and target character ids are embedded separately
    (``self.Wchar_s`` / ``self.Wchar_t``), concatenated feature-wise and fed
    through an attention encoder-decoder. The graph is built twice in the
    same variable scope:

    * first with ``self.char_rnn_cell_infer``;
    * then, after ``scope.reuse_variables()``, with
      ``self.char_rnn_cell_train``.

    Both passes therefore share the 'encoder'/'decoder' variables. Both
    passes use ``attention_decoder_fn_train``, so the decoder is fed
    ``chars`` as its inputs in each pass.

    Args:
        name: suffix appended to ``self._variable_scope`` to form the
            component's variable scope (also used as the feature-extractor
            prefix).
        max_len: maximum word length in characters. It is increased by 2
            when ``self._word_delimiters`` is set (presumably room for
            start/end delimiters -- TODO confirm).
        reuse: if True, reuse the existing variables of the component scope.

    Returns:
        A ``Component`` with:

        * inputs ``[chars_p_s, word_lens_source, chars_p_t, word_lens_target]``;
        * ``output`` / ``output_infer`` computed by ``_get_last_state_dyn``
          over the train-cell and infer-cell decoder outputs;
        * a feature extractor combining the source- and target-side
          ``CharLevelInputExtraction``.
    """
    component_scope = self._variable_scope + name
    with tf.variable_scope(component_scope) as scope:
        if reuse:
            scope.reuse_variables()
        max_len = max_len + 2 if self._word_delimiters else max_len
        # Per-word lengths: one int per word (batch).
        word_lens_source = tf.placeholder(dtype=tf.int32, shape=[None], name='source/word_lens')
        # NOTE(review): this target-side placeholder is named
        # 'source/word_target', while the other target placeholder uses a
        # 'target/' prefix. It looks like a typo, but the name is part of the
        # graph -- confirm nothing looks it up by name before renaming.
        word_lens_target = tf.placeholder(dtype=tf.int32, shape=[None], name='source/word_target')
        # Character ids laid out as [max_len, batch].
        chars_p_s = tf.placeholder(dtype=tf.int32, shape=[max_len, None], name='source/char_sequence%i' % max_len)
        chars_p_t = tf.placeholder(dtype=tf.int32, shape=[max_len, None], name='target/char_sequence%i' % max_len)
        # Embed, then transpose [max_len, batch, emb] -> [batch, max_len, emb].
        chars_s = tf.nn.embedding_lookup(self.Wchar_s, chars_p_s)
        chars_s = tf.transpose(chars_s, [1, 0, 2])
        chars_t = tf.nn.embedding_lookup(self.Wchar_t, chars_p_t)
        chars_t = tf.transpose(chars_t, [1, 0, 2])
        # Concatenate source and target character embeddings on the feature axis.
        chars = tf.concat([chars_s, chars_t], 2)
        # First pass: build encoder/attention/decoder with the infer cell
        # (creates the variables unless `reuse` was requested).
        enc_output_infer, enc_state_infer = tf.nn.dynamic_rnn(self.char_rnn_cell_infer, chars, dtype=tf.float32,
                                                              sequence_length=tf.cast(word_lens_source, dtype=tf.int64),
                                                              swap_memory=True, scope='encoder')
        attn_keys, attn_values, attn_score_fn, attn_construct_fn = attention_decoder_fn.prepare_attention(enc_output_infer, 'luong', self.char_rnn_cell_infer.output_size)
        dec_fn_inf = attention_decoder_fn.attention_decoder_fn_train(
            enc_state_infer, attn_keys, attn_values, attn_score_fn, attn_construct_fn)
        outputs_infer, _, _ = seq2seq.dynamic_rnn_decoder(self.char_rnn_cell_infer, dec_fn_inf, inputs=chars,
                                                          sequence_length=word_lens_target, swap_memory=True,
                                                          scope='decoder')
        # Second pass: same scopes, now reusing the variables, with the train cell.
        scope.reuse_variables()
        enc_output, enc_state = tf.nn.dynamic_rnn(self.char_rnn_cell_train, chars, dtype=tf.float32,
                                                  sequence_length=tf.cast(word_lens_source, dtype=tf.int64),
                                                  swap_memory=True, scope='encoder')
        # NOTE(review): the train pass also sizes attention with the infer
        # cell's output_size; presumably both cells have the same size -- confirm.
        attn_keys, attn_values, attn_score_fn, attn_construct_fn = attention_decoder_fn.prepare_attention(
            enc_output, 'luong', self.char_rnn_cell_infer.output_size)
        dec_fn = attention_decoder_fn.attention_decoder_fn_train(
            enc_state, attn_keys, attn_values, attn_score_fn, attn_construct_fn)
        outputs, _, _ = seq2seq.dynamic_rnn_decoder(self.char_rnn_cell_train, dec_fn, inputs=chars,
                                                    sequence_length=word_lens_target, swap_memory=True,
                                                    scope='decoder')
        # Reduce the decoder outputs with `_get_last_state_dyn` (presumably the
        # output at each word's last valid step, norm-limited by self.max_norm
        # -- TODO confirm against its definition).
        output = _get_last_state_dyn(self.max_norm, word_lens_target, outputs)
        output_infer = _get_last_state_dyn(self.max_norm, word_lens_target, outputs_infer)
        inputs = [
            chars_p_s, word_lens_source, chars_p_t, word_lens_target
        ]
        char_feature_extractor1 = CharLevelInputExtraction(self._char_vocab_source, max_len, component_scope + 'source/')
        char_feature_extractor2 = CharLevelInputExtraction(self._char_vocab_target, max_len, component_scope + 'target/')
        char_feature_extractor = CombineFeatureExtraction(char_feature_extractor1, char_feature_extractor2)
        return Component(inputs, output, output_infer=output_infer, feature_extractor=char_feature_extractor,
                         name='c_rnn_joint')
def __init__(self, num_symbols, num_embed_units, num_units, num_layers,
             is_train, vocab=None, embed=None, learning_rate=0.1,
             learning_rate_decay_factor=0.95, max_gradient_norm=5.0,
             num_samples=512, max_length=30, use_lstm=True):
    """Build a multi-turn, knowledge-aware seq2seq graph.

    Four utterances are chained:

    * ``posts_1`` is encoded by ``cell`` under scope "encoder".
    * ``posts_2``, ``posts_3``, ``posts_4`` and finally ``responses`` are
      each decoded under scope "decoder", with shared weights.
    * Each decode attends over the previous stage's RNN outputs and over a
      graph embedding of the previous utterance's knowledge triples.

    The decoding losses for ``posts_2..4`` are always built. The training
    branch adds the response loss and a momentum optimizer; the inference
    branch adds a greedy decoder.

    Args:
        num_symbols: vocabulary size (rows of ``self.embed``, size of the
            symbol table and of the output projection).
        num_embed_units: word-embedding width. Each triple is embedded as
            three words (head, relation, tail).
        num_units: RNN hidden size; also passed as the attention size.
        num_layers: number of stacked RNN cells.
        is_train: build the training branch (loss + optimizer) when True,
            the greedy inference branch otherwise.
        vocab: initial value of ``self.symbols``; only used when
            ``is_train`` is True.
        embed: optional initial value for the embedding matrix. When None,
            the matrix is created by ``tf.get_variable`` without an explicit
            initializer.
        learning_rate: initial learning rate (training only).
        learning_rate_decay_factor: multiplier applied by
            ``self.learning_rate_decay_op`` (training only).
        max_gradient_norm: global-norm clipping threshold (training only).
        num_samples: forwarded to ``output_projection_layer`` (presumably
            the sampled-softmax sample count -- confirm).
        max_length: maximum decoding length at inference.
        use_lstm: use ``LSTMCell`` when True, ``GRUCell`` otherwise.
    """
    # ---- Placeholders ---------------------------------------------------
    # Four consecutive utterances as word strings, [batch, length].
    self.posts_1 = tf.placeholder(tf.string, shape=(None, None))
    self.posts_2 = tf.placeholder(tf.string, shape=(None, None))
    self.posts_3 = tf.placeholder(tf.string, shape=(None, None))
    self.posts_4 = tf.placeholder(tf.string, shape=(None, None))
    # Knowledge triples per utterance. The trailing axis of size 3 holds the
    # (head, relation, tail) words that are split apart below.
    self.entity_1 = tf.placeholder(tf.string, shape=(None, None, None, 3))
    self.entity_2 = tf.placeholder(tf.string, shape=(None, None, None, 3))
    self.entity_3 = tf.placeholder(tf.string, shape=(None, None, None, 3))
    self.entity_4 = tf.placeholder(tf.string, shape=(None, None, None, 3))
    # Float masks over the triples; multiplied into the attended embeddings.
    self.entity_mask_1 = tf.placeholder(tf.float32, shape=(None, None, None))
    self.entity_mask_2 = tf.placeholder(tf.float32, shape=(None, None, None))
    self.entity_mask_3 = tf.placeholder(tf.float32, shape=(None, None, None))
    self.entity_mask_4 = tf.placeholder(tf.float32, shape=(None, None, None))
    self.posts_length_1 = tf.placeholder(tf.int32, shape=(None))
    self.posts_length_2 = tf.placeholder(tf.int32, shape=(None))
    self.posts_length_3 = tf.placeholder(tf.int32, shape=(None))
    self.posts_length_4 = tf.placeholder(tf.int32, shape=(None))
    self.responses = tf.placeholder(tf.string, shape=(None, None))
    self.responses_length = tf.placeholder(tf.int32, shape=(None))

    self.epoch = tf.Variable(0, trainable=False, name='epoch')
    self.epoch_add_op = self.epoch.assign(self.epoch + 1)

    # ---- Vocabulary: string <-> id ----------------------------------------
    if is_train:
        self.symbols = tf.Variable(vocab, trainable=False, name="symbols")
    else:
        # Dummy vocabulary (presumably overwritten when a checkpoint is
        # restored). It is a string variable, so it is never trainable.
        self.symbols = tf.Variable(np.array(['.'] * num_symbols),
                                   trainable=False, name="symbols")
    self.symbol2index = HashTable(
        KeyValueTensorInitializer(
            self.symbols,
            tf.Variable(np.arange(num_symbols, dtype=np.int32),
                        trainable=False)),
        default_value=UNK_ID, name="symbol2index")

    self.posts_input_1 = self.symbol2index.lookup(self.posts_1)
    self.posts_2_target = self.posts_2_embed = self.symbol2index.lookup(self.posts_2)
    self.posts_3_target = self.posts_3_embed = self.symbol2index.lookup(self.posts_3)
    self.posts_4_target = self.posts_4_embed = self.symbol2index.lookup(self.posts_4)
    self.responses_target = self.symbol2index.lookup(self.responses)
    batch_size, decoder_len = tf.shape(self.posts_1)[0], tf.shape(self.responses)[1]

    # Decoder inputs: the target ids shifted right by one, with GO_ID first.
    self.posts_input_2 = tf.concat([
        tf.ones([batch_size, 1], dtype=tf.int32) * GO_ID,
        tf.split(self.posts_2_embed, [tf.shape(self.posts_2)[1] - 1, 1], 1)[0]
    ], 1)
    self.posts_input_3 = tf.concat([
        tf.ones([batch_size, 1], dtype=tf.int32) * GO_ID,
        tf.split(self.posts_3_embed, [tf.shape(self.posts_3)[1] - 1, 1], 1)[0]
    ], 1)
    self.posts_input_4 = tf.concat([
        tf.ones([batch_size, 1], dtype=tf.int32) * GO_ID,
        tf.split(self.posts_4_embed, [tf.shape(self.posts_4)[1] - 1, 1], 1)[0]
    ], 1)
    self.responses_input = tf.concat([
        tf.ones([batch_size, 1], dtype=tf.int32) * GO_ID,
        tf.split(self.responses_target, [decoder_len - 1, 1], 1)[0]
    ], 1)

    # Length masks: one_hot at (length - 1) followed by a reverse cumsum
    # gives 1.0 for positions < length and 0.0 afterwards.
    self.encoder_2_mask = tf.reshape(
        tf.cumsum(tf.one_hot(self.posts_length_2 - 1, tf.shape(self.posts_2)[1]),
                  reverse=True, axis=1),
        [-1, tf.shape(self.posts_2)[1]])
    self.encoder_3_mask = tf.reshape(
        tf.cumsum(tf.one_hot(self.posts_length_3 - 1, tf.shape(self.posts_3)[1]),
                  reverse=True, axis=1),
        [-1, tf.shape(self.posts_3)[1]])
    self.encoder_4_mask = tf.reshape(
        tf.cumsum(tf.one_hot(self.posts_length_4 - 1, tf.shape(self.posts_4)[1]),
                  reverse=True, axis=1),
        [-1, tf.shape(self.posts_4)[1]])
    self.decoder_mask = tf.reshape(
        tf.cumsum(tf.one_hot(self.responses_length - 1, decoder_len),
                  reverse=True, axis=1),
        [-1, decoder_len])

    # ---- Word embeddings ----------------------------------------------------
    if embed is None:
        self.embed = tf.get_variable('embed', [num_symbols, num_embed_units], tf.float32)
    else:
        self.embed = tf.get_variable('embed', dtype=tf.float32, initializer=embed)

    self.encoder_input_1 = tf.nn.embedding_lookup(self.embed, self.posts_input_1)
    self.encoder_input_2 = tf.nn.embedding_lookup(self.embed, self.posts_input_2)
    self.encoder_input_3 = tf.nn.embedding_lookup(self.embed, self.posts_input_3)
    self.encoder_input_4 = tf.nn.embedding_lookup(self.embed, self.posts_input_4)
    self.decoder_input = tf.nn.embedding_lookup(self.embed, self.responses_input)

    # ---- Graph attention over knowledge triples -------------------------------
    def _triple_embeddings(entity):
        """Embed each triple's three words; return (head, relation, tail)."""
        entity_embedding = tf.reshape(
            tf.nn.embedding_lookup(self.embed, self.symbol2index.lookup(entity)),
            [batch_size, tf.shape(entity)[1], tf.shape(entity)[2],
             3 * num_embed_units])
        return tf.split(entity_embedding, [num_embed_units] * 3, axis=3)

    def _graph_attention(head, relation, tail, entity_mask, reuse):
        """Attend over the triples of one utterance in scope 'graph_attention'."""
        with tf.variable_scope('graph_attention', reuse=reuse):
            # [batch_size, max_reponse_length, max_triple_num, 2*embed_units]
            head_tail = tf.concat([head, tail], axis=3)
            # [batch_size, max_reponse_length, max_triple_num, embed_units]
            head_tail_transformed = tf.layers.dense(
                head_tail, num_embed_units, activation=tf.tanh,
                name='head_tail_transform')
            # [batch_size, max_reponse_length, max_triple_num, embed_units]
            relation_transformed = tf.layers.dense(
                relation, num_embed_units, name='relation_transform')
            # [batch_size, max_reponse_length, max_triple_num]
            e_weight = tf.reduce_sum(relation_transformed * head_tail_transformed, axis=3)
            # NOTE(review): the softmax is taken before the entity mask is
            # applied, so masked-out triples still receive probability mass.
            alpha_weight = tf.nn.softmax(e_weight)
            # Weighted sum over the triple axis (axis 2) of the masked
            # head/tail concatenation.
            return tf.reduce_sum(
                tf.expand_dims(alpha_weight, 3)
                * (tf.expand_dims(entity_mask, 3) * head_tail),
                axis=2)

    head_1, relation_1, tail_1 = _triple_embeddings(self.entity_1)
    head_2, relation_2, tail_2 = _triple_embeddings(self.entity_2)
    head_3, relation_3, tail_3 = _triple_embeddings(self.entity_3)
    head_4, relation_4, tail_4 = _triple_embeddings(self.entity_4)
    # The first call creates the 'graph_attention' weights; the others reuse them.
    graph_embed_1 = _graph_attention(head_1, relation_1, tail_1, self.entity_mask_1, reuse=None)
    graph_embed_2 = _graph_attention(head_2, relation_2, tail_2, self.entity_mask_2, reuse=True)
    graph_embed_3 = _graph_attention(head_3, relation_3, tail_3, self.entity_mask_3, reuse=True)
    graph_embed_4 = _graph_attention(head_4, relation_4, tail_4, self.entity_mask_4, reuse=True)

    # ---- RNN encoder / chained decoders ---------------------------------------
    # Build one cell object per layer. `[cell] * num_layers` would put the
    # same RNNCell instance in every layer, which newer TF 1.x releases reject.
    cell_class = LSTMCell if use_lstm else GRUCell
    cell = MultiRNNCell([cell_class(num_units) for _ in range(num_layers)])

    output_fn, sampled_sequence_loss = output_projection_layer(
        num_units, num_symbols, num_samples)

    encoder_output_1, encoder_state_1 = dynamic_rnn(
        cell, self.encoder_input_1, self.posts_length_1,
        dtype=tf.float32, scope="encoder")

    # Stage 1: decode posts_2 from posts_1 (creates the "decoder" and
    # attention variables).
    attention_keys_1, attention_values_1, attention_score_fn_1, attention_construct_fn_1 \
        = attention_decoder_fn.prepare_attention(graph_embed_1, encoder_output_1, 'luong', num_units)
    decoder_fn_train_1 = attention_decoder_fn.attention_decoder_fn_train(
        encoder_state_1, attention_keys_1, attention_values_1,
        attention_score_fn_1, attention_construct_fn_1,
        max_length=tf.reduce_max(self.posts_length_2))
    encoder_output_2, encoder_state_2, alignments_ta_2 = dynamic_rnn_decoder(
        cell, decoder_fn_train_1, self.encoder_input_2, self.posts_length_2,
        scope="decoder")
    self.alignments_2 = tf.transpose(alignments_ta_2.stack(), perm=[1, 0, 2])
    self.decoder_loss_2 = sampled_sequence_loss(
        encoder_output_2, self.posts_2_target, self.encoder_2_mask)

    # Stages 2-3 and the response-attention setup reuse those variables.
    with variable_scope.variable_scope('', reuse=True):
        attention_keys_2, attention_values_2, attention_score_fn_2, attention_construct_fn_2 \
            = attention_decoder_fn.prepare_attention(graph_embed_2, encoder_output_2, 'luong', num_units)
        decoder_fn_train_2 = attention_decoder_fn.attention_decoder_fn_train(
            encoder_state_2, attention_keys_2, attention_values_2,
            attention_score_fn_2, attention_construct_fn_2,
            max_length=tf.reduce_max(self.posts_length_3))
        encoder_output_3, encoder_state_3, alignments_ta_3 = dynamic_rnn_decoder(
            cell, decoder_fn_train_2, self.encoder_input_3, self.posts_length_3,
            scope="decoder")
        self.alignments_3 = tf.transpose(alignments_ta_3.stack(), perm=[1, 0, 2])
        self.decoder_loss_3 = sampled_sequence_loss(
            encoder_output_3, self.posts_3_target, self.encoder_3_mask)

        attention_keys_3, attention_values_3, attention_score_fn_3, attention_construct_fn_3 \
            = attention_decoder_fn.prepare_attention(graph_embed_3, encoder_output_3, 'luong', num_units)
        decoder_fn_train_3 = attention_decoder_fn.attention_decoder_fn_train(
            encoder_state_3, attention_keys_3, attention_values_3,
            attention_score_fn_3, attention_construct_fn_3,
            max_length=tf.reduce_max(self.posts_length_4))
        encoder_output_4, encoder_state_4, alignments_ta_4 = dynamic_rnn_decoder(
            cell, decoder_fn_train_3, self.encoder_input_4, self.posts_length_4,
            scope="decoder")
        self.alignments_4 = tf.transpose(alignments_ta_4.stack(), perm=[1, 0, 2])
        self.decoder_loss_4 = sampled_sequence_loss(
            encoder_output_4, self.posts_4_target, self.encoder_4_mask)

        attention_keys, attention_values, attention_score_fn, attention_construct_fn \
            = attention_decoder_fn.prepare_attention(graph_embed_4, encoder_output_4, 'luong', num_units)

    if is_train:
        with variable_scope.variable_scope('', reuse=True):
            decoder_fn_train = attention_decoder_fn.attention_decoder_fn_train(
                encoder_state_4, attention_keys, attention_values,
                attention_score_fn, attention_construct_fn,
                max_length=tf.reduce_max(self.responses_length))
            self.decoder_output, _, alignments_ta = dynamic_rnn_decoder(
                cell, decoder_fn_train, self.decoder_input, self.responses_length,
                scope="decoder")
            self.alignments = tf.transpose(alignments_ta.stack(), perm=[1, 0, 2])
            self.decoder_loss = sampled_sequence_loss(
                self.decoder_output, self.responses_target, self.decoder_mask)

        # The optimizer is created outside the reuse scope so that its slot
        # variables can be created.
        self.params = tf.trainable_variables()
        self.learning_rate = tf.Variable(float(learning_rate), trainable=False, dtype=tf.float32)
        self.learning_rate_decay_op = self.learning_rate.assign(
            self.learning_rate * learning_rate_decay_factor)
        self.global_step = tf.Variable(0, trainable=False)
        opt = tf.train.MomentumOptimizer(self.learning_rate, 0.9)
        # Jointly minimize the response loss and the three utterance losses.
        gradients = tf.gradients(
            self.decoder_loss + self.decoder_loss_2 + self.decoder_loss_3 + self.decoder_loss_4,
            self.params)
        clipped_gradients, self.gradient_norm = tf.clip_by_global_norm(
            gradients, max_gradient_norm)
        self.update = opt.apply_gradients(zip(clipped_gradients, self.params),
                                          global_step=self.global_step)
    else:
        with variable_scope.variable_scope('', reuse=True):
            decoder_fn_inference = attention_decoder_fn.attention_decoder_fn_inference(
                output_fn, encoder_state_4, attention_keys, attention_values,
                attention_score_fn, attention_construct_fn, self.embed,
                GO_ID, EOS_ID, max_length, num_symbols)
            self.decoder_distribution, _, alignments_ta = dynamic_rnn_decoder(
                cell, decoder_fn_inference, scope="decoder")
            output_len = tf.shape(self.decoder_distribution)[1]
            self.alignments = tf.transpose(
                alignments_ta.gather(tf.range(output_len)), [1, 0, 2])
            # Argmax over symbol ids >= 2 only (ids 0 and 1 are sliced off,
            # then +2 restores the offset); per the original comment this
            # removes UNK from the generation.
            self.generation_index = tf.argmax(
                tf.split(self.decoder_distribution, [2, num_symbols - 2], 2)[1], 2) + 2
            self.generation = tf.nn.embedding_lookup(
                self.symbols, self.generation_index, name="generation")
        self.params = tf.trainable_variables()

    self.saver = tf.train.Saver(tf.global_variables(),
                                write_version=tf.train.SaverDef.V2,
                                max_to_keep=10,
                                pad_step_number=True,
                                keep_checkpoint_every_n_hours=1.0)
def test_dynamic_rnn_decoder_time_major(self):
    """Train and inference `dynamic_rnn_decoder` with time-major inputs.

    Builds a GRU encoder and a GRU decoder (simple train/inference decoder
    functions wrapped by `_decoder_fn_with_context_state`), runs both and
    checks output/state shapes and the final context state (step count).
    """
    with self.test_session() as sess:
        with variable_scope.variable_scope(
                "root", initializer=init_ops.constant_initializer(0.5)):
            # Define inputs/outputs to model
            batch_size = 2
            encoder_embedding_size = 3
            decoder_embedding_size = 4
            encoder_hidden_size = 5
            decoder_hidden_size = encoder_hidden_size
            input_sequence_length = 6
            decoder_sequence_length = 7
            num_decoder_symbols = 20
            start_of_sequence_id = end_of_sequence_id = 1
            decoder_embeddings = variable_scope.get_variable(
                "decoder_embeddings",
                [num_decoder_symbols, decoder_embedding_size],
                initializer=init_ops.random_normal_initializer(stddev=0.1))
            inputs = constant_op.constant(
                0.5,
                shape=[input_sequence_length, batch_size, encoder_embedding_size])
            decoder_inputs = constant_op.constant(
                0.4,
                shape=[decoder_sequence_length, batch_size, decoder_embedding_size])
            decoder_length = constant_op.constant(
                decoder_sequence_length, dtype=dtypes.int32, shape=[batch_size,])
            with variable_scope.variable_scope("rnn") as scope:
                # setting up weights for computing the final output
                output_fn = lambda x: layers.linear(x, num_decoder_symbols, scope=scope)

                # Define model
                encoder_outputs, encoder_state = rnn.dynamic_rnn(
                    cell=core_rnn_cell_impl.GRUCell(encoder_hidden_size),
                    inputs=inputs,
                    dtype=dtypes.float32,
                    time_major=True,
                    scope=scope)

            with variable_scope.variable_scope("decoder") as scope:
                # Train decoder
                decoder_cell = core_rnn_cell_impl.GRUCell(decoder_hidden_size)
                decoder_fn_train = Seq2SeqTest._decoder_fn_with_context_state(
                    decoder_fn_lib.simple_decoder_fn_train(
                        encoder_state=encoder_state))
                (decoder_outputs_train, decoder_state_train,
                 decoder_context_state_train) = (seq2seq.dynamic_rnn_decoder(
                     cell=decoder_cell,
                     decoder_fn=decoder_fn_train,
                     inputs=decoder_inputs,
                     sequence_length=decoder_length,
                     time_major=True,
                     scope=scope))
                decoder_outputs_train = output_fn(decoder_outputs_train)

                # Setup variable reuse
                scope.reuse_variables()

                # Inference decoder
                decoder_fn_inference = Seq2SeqTest._decoder_fn_with_context_state(
                    decoder_fn_lib.simple_decoder_fn_inference(
                        output_fn=output_fn,
                        encoder_state=encoder_state,
                        embeddings=decoder_embeddings,
                        start_of_sequence_id=start_of_sequence_id,
                        end_of_sequence_id=end_of_sequence_id,
                        #TODO: find out why it goes to +1
                        maximum_length=decoder_sequence_length - 1,
                        num_decoder_symbols=num_decoder_symbols,
                        dtype=dtypes.int32))
                (decoder_outputs_inference, decoder_state_inference,
                 decoder_context_state_inference) = (seq2seq.dynamic_rnn_decoder(
                     cell=decoder_cell,
                     decoder_fn=decoder_fn_inference,
                     time_major=True,
                     scope=scope))

            # Run model
            variables.global_variables_initializer().run()
            (decoder_outputs_train_res, decoder_state_train_res,
             decoder_context_state_train_res) = sess.run([
                 decoder_outputs_train, decoder_state_train,
                 decoder_context_state_train
             ])
            (decoder_outputs_inference_res, decoder_state_inference_res,
             decoder_context_state_inference_res) = sess.run([
                 decoder_outputs_inference, decoder_state_inference,
                 decoder_context_state_inference
             ])

            # Assert outputs
            self.assertEqual(
                (decoder_sequence_length, batch_size, num_decoder_symbols),
                decoder_outputs_train_res.shape)
            self.assertEqual((batch_size, num_decoder_symbols),
                             decoder_outputs_inference_res.shape[1:3])
            self.assertEqual(decoder_sequence_length,
                             decoder_context_state_inference_res)
            self.assertEqual((batch_size, decoder_hidden_size),
                             decoder_state_train_res.shape)
            self.assertEqual((batch_size, decoder_hidden_size),
                             decoder_state_inference_res.shape)
            self.assertEqual(decoder_sequence_length,
                             decoder_context_state_train_res)
            # The dynamic decoder might end earlier than `maximum_length`
            # under inference, so bound the inference outputs' time axis
            # (axis 0, time-major). The final state's axis 0 is the batch
            # axis, so checking it would be vacuous.
            self.assertGreaterEqual(decoder_sequence_length,
                                    decoder_outputs_inference_res.shape[0])
def test_attention(self):
    """Train and inference `dynamic_rnn_decoder` with Luong attention.

    Encodes constant inputs with a GRU, prepares attention over the
    batch-major encoder outputs, runs the attention train and inference
    decoder functions and checks output/state shapes.
    """
    with self.test_session() as sess:
        with variable_scope.variable_scope(
                "root", initializer=init_ops.constant_initializer(0.5)):
            # Define inputs/outputs to model
            batch_size = 2
            encoder_embedding_size = 3
            decoder_embedding_size = 4
            encoder_hidden_size = 5
            decoder_hidden_size = encoder_hidden_size
            input_sequence_length = 6
            decoder_sequence_length = 7
            num_decoder_symbols = 20
            start_of_sequence_id = end_of_sequence_id = 1
            decoder_embeddings = variable_scope.get_variable(
                "decoder_embeddings",
                [num_decoder_symbols, decoder_embedding_size],
                initializer=init_ops.random_normal_initializer(stddev=0.1))
            inputs = constant_op.constant(
                0.5,
                shape=[input_sequence_length, batch_size, encoder_embedding_size])
            decoder_inputs = constant_op.constant(
                0.4,
                shape=[decoder_sequence_length, batch_size, decoder_embedding_size])
            decoder_length = constant_op.constant(
                decoder_sequence_length, dtype=dtypes.int32, shape=[batch_size,])

            # attention
            attention_option = "luong"  # can be "bahdanau"

            with variable_scope.variable_scope("rnn") as scope:
                # Define model
                encoder_outputs, encoder_state = rnn.dynamic_rnn(
                    cell=core_rnn_cell_impl.GRUCell(encoder_hidden_size),
                    inputs=inputs,
                    dtype=dtypes.float32,
                    time_major=True,
                    scope=scope)

            # attention_states: size [batch_size, max_time, num_units]
            attention_states = array_ops.transpose(encoder_outputs, [1, 0, 2])

            with variable_scope.variable_scope("decoder") as scope:
                # Prepare attention
                (attention_keys, attention_values, attention_score_fn,
                 attention_construct_fn) = (attention_decoder_fn.prepare_attention(
                     attention_states, attention_option, decoder_hidden_size))
                decoder_fn_train = attention_decoder_fn.attention_decoder_fn_train(
                    encoder_state=encoder_state,
                    attention_keys=attention_keys,
                    attention_values=attention_values,
                    attention_score_fn=attention_score_fn,
                    attention_construct_fn=attention_construct_fn)

                # setting up weights for computing the final output
                def create_output_fn():
                    def output_fn(x):
                        return layers.linear(x, num_decoder_symbols, scope=scope)

                    return output_fn

                output_fn = create_output_fn()

                # Train decoder
                decoder_cell = core_rnn_cell_impl.GRUCell(decoder_hidden_size)
                (decoder_outputs_train, decoder_state_train, _) = (
                    seq2seq.dynamic_rnn_decoder(
                        cell=decoder_cell,
                        decoder_fn=decoder_fn_train,
                        inputs=decoder_inputs,
                        sequence_length=decoder_length,
                        time_major=True,
                        scope=scope))
                decoder_outputs_train = output_fn(decoder_outputs_train)
                # Setup variable reuse
                scope.reuse_variables()

                # Inference decoder
                decoder_fn_inference = (
                    attention_decoder_fn.attention_decoder_fn_inference(
                        output_fn=output_fn,
                        encoder_state=encoder_state,
                        attention_keys=attention_keys,
                        attention_values=attention_values,
                        attention_score_fn=attention_score_fn,
                        attention_construct_fn=attention_construct_fn,
                        embeddings=decoder_embeddings,
                        start_of_sequence_id=start_of_sequence_id,
                        end_of_sequence_id=end_of_sequence_id,
                        maximum_length=decoder_sequence_length - 1,
                        num_decoder_symbols=num_decoder_symbols,
                        dtype=dtypes.int32))
                (decoder_outputs_inference, decoder_state_inference, _) = (
                    seq2seq.dynamic_rnn_decoder(
                        cell=decoder_cell,
                        decoder_fn=decoder_fn_inference,
                        time_major=True,
                        scope=scope))

            # Run model
            variables.global_variables_initializer().run()
            (decoder_outputs_train_res, decoder_state_train_res) = sess.run(
                [decoder_outputs_train, decoder_state_train])
            (decoder_outputs_inference_res, decoder_state_inference_res) = sess.run(
                [decoder_outputs_inference, decoder_state_inference])

            # Assert outputs
            self.assertEqual(
                (decoder_sequence_length, batch_size, num_decoder_symbols),
                decoder_outputs_train_res.shape)
            self.assertEqual((batch_size, num_decoder_symbols),
                             decoder_outputs_inference_res.shape[1:3])
            self.assertEqual((batch_size, decoder_hidden_size),
                             decoder_state_train_res.shape)
            self.assertEqual((batch_size, decoder_hidden_size),
                             decoder_state_inference_res.shape)
            # The dynamic decoder might end earlier than `maximum_length`
            # under inference, so bound the inference outputs' time axis
            # (axis 0, time-major). The final state's axis 0 is the batch
            # axis, so checking it would be vacuous.
            self.assertGreaterEqual(decoder_sequence_length,
                                    decoder_outputs_inference_res.shape[0])
def test_dynamic_rnn_decoder_time_major(self):
    """Train and inference `dynamic_rnn_decoder` with time-major inputs.

    Builds a GRU encoder and a GRU decoder (simple train/inference decoder
    functions wrapped by `_decoder_fn_with_context_state`), runs both and
    checks output/state shapes and the final context state (step count).
    """
    with self.test_session() as sess:
        with variable_scope.variable_scope(
                "root", initializer=init_ops.constant_initializer(0.5)):
            # Define inputs/outputs to model
            batch_size = 2
            encoder_embedding_size = 3
            decoder_embedding_size = 4
            encoder_hidden_size = 5
            decoder_hidden_size = encoder_hidden_size
            input_sequence_length = 6
            decoder_sequence_length = 7
            num_decoder_symbols = 20
            start_of_sequence_id = end_of_sequence_id = 1
            decoder_embeddings = variable_scope.get_variable(
                "decoder_embeddings",
                [num_decoder_symbols, decoder_embedding_size],
                initializer=init_ops.random_normal_initializer(stddev=0.1))
            inputs = constant_op.constant(0.5, shape=[
                input_sequence_length, batch_size, encoder_embedding_size
            ])
            decoder_inputs = constant_op.constant(
                0.4, shape=[
                    decoder_sequence_length, batch_size, decoder_embedding_size
                ])
            decoder_length = constant_op.constant(decoder_sequence_length,
                                                  dtype=dtypes.int32,
                                                  shape=[
                                                      batch_size,
                                                  ])
            with variable_scope.variable_scope("rnn") as scope:
                # setting up weights for computing the final output
                output_fn = lambda x: layers.linear(
                    x, num_decoder_symbols, scope=scope)

                # Define model
                encoder_outputs, encoder_state = rnn.dynamic_rnn(
                    cell=core_rnn_cell_impl.GRUCell(encoder_hidden_size),
                    inputs=inputs,
                    dtype=dtypes.float32,
                    time_major=True,
                    scope=scope)

            with variable_scope.variable_scope("decoder") as scope:
                # Train decoder
                decoder_cell = core_rnn_cell_impl.GRUCell(
                    decoder_hidden_size)
                decoder_fn_train = Seq2SeqTest._decoder_fn_with_context_state(
                    decoder_fn_lib.simple_decoder_fn_train(
                        encoder_state=encoder_state))
                (decoder_outputs_train, decoder_state_train,
                 decoder_context_state_train) = (
                     seq2seq.dynamic_rnn_decoder(
                         cell=decoder_cell,
                         decoder_fn=decoder_fn_train,
                         inputs=decoder_inputs,
                         sequence_length=decoder_length,
                         time_major=True,
                         scope=scope))
                decoder_outputs_train = output_fn(decoder_outputs_train)

                # Setup variable reuse
                scope.reuse_variables()

                # Inference decoder
                decoder_fn_inference = Seq2SeqTest._decoder_fn_with_context_state(
                    decoder_fn_lib.simple_decoder_fn_inference(
                        output_fn=output_fn,
                        encoder_state=encoder_state,
                        embeddings=decoder_embeddings,
                        start_of_sequence_id=start_of_sequence_id,
                        end_of_sequence_id=end_of_sequence_id,
                        #TODO: find out why it goes to +1
                        maximum_length=decoder_sequence_length - 1,
                        num_decoder_symbols=num_decoder_symbols,
                        dtype=dtypes.int32))
                (decoder_outputs_inference, decoder_state_inference,
                 decoder_context_state_inference) = (
                     seq2seq.dynamic_rnn_decoder(
                         cell=decoder_cell,
                         decoder_fn=decoder_fn_inference,
                         time_major=True,
                         scope=scope))

            # Run model
            variables.global_variables_initializer().run()
            (decoder_outputs_train_res, decoder_state_train_res,
             decoder_context_state_train_res) = sess.run([
                 decoder_outputs_train, decoder_state_train,
                 decoder_context_state_train
             ])
            (decoder_outputs_inference_res, decoder_state_inference_res,
             decoder_context_state_inference_res) = sess.run([
                 decoder_outputs_inference, decoder_state_inference,
                 decoder_context_state_inference
             ])

            # Assert outputs
            self.assertEqual(
                (decoder_sequence_length, batch_size, num_decoder_symbols),
                decoder_outputs_train_res.shape)
            self.assertEqual((batch_size, num_decoder_symbols),
                             decoder_outputs_inference_res.shape[1:3])
            self.assertEqual(decoder_sequence_length,
                             decoder_context_state_inference_res)
            self.assertEqual((batch_size, decoder_hidden_size),
                             decoder_state_train_res.shape)
            self.assertEqual((batch_size, decoder_hidden_size),
                             decoder_state_inference_res.shape)
            self.assertEqual(decoder_sequence_length,
                             decoder_context_state_train_res)
            # The dynamic decoder might end earlier than `maximum_length`
            # under inference, so bound the inference outputs' time axis
            # (axis 0, time-major). The final state's axis 0 is the batch
            # axis, so checking it would be vacuous.
            self.assertGreaterEqual(decoder_sequence_length,
                                    decoder_outputs_inference_res.shape[0])
def test_attention(self):
    """Train and inference `dynamic_rnn_decoder` with Luong attention.

    Encodes constant inputs with a GRU, prepares attention over the
    batch-major encoder outputs, runs the attention train and inference
    decoder functions and checks output/state shapes.
    """
    with self.test_session() as sess:
        with variable_scope.variable_scope(
                "root", initializer=init_ops.constant_initializer(0.5)):
            # Define inputs/outputs to model
            batch_size = 2
            encoder_embedding_size = 3
            decoder_embedding_size = 4
            encoder_hidden_size = 5
            decoder_hidden_size = encoder_hidden_size
            input_sequence_length = 6
            decoder_sequence_length = 7
            num_decoder_symbols = 20
            start_of_sequence_id = end_of_sequence_id = 1
            decoder_embeddings = variable_scope.get_variable(
                "decoder_embeddings",
                [num_decoder_symbols, decoder_embedding_size],
                initializer=init_ops.random_normal_initializer(stddev=0.1))
            inputs = constant_op.constant(0.5, shape=[
                input_sequence_length, batch_size, encoder_embedding_size
            ])
            decoder_inputs = constant_op.constant(
                0.4, shape=[
                    decoder_sequence_length, batch_size, decoder_embedding_size
                ])
            decoder_length = constant_op.constant(decoder_sequence_length,
                                                  dtype=dtypes.int32,
                                                  shape=[
                                                      batch_size,
                                                  ])

            # attention
            attention_option = "luong"  # can be "bahdanau"

            with variable_scope.variable_scope("rnn") as scope:
                # Define model
                encoder_outputs, encoder_state = rnn.dynamic_rnn(
                    cell=core_rnn_cell_impl.GRUCell(encoder_hidden_size),
                    inputs=inputs,
                    dtype=dtypes.float32,
                    time_major=True,
                    scope=scope)

            # attention_states: size [batch_size, max_time, num_units]
            attention_states = array_ops.transpose(
                encoder_outputs, [1, 0, 2])

            with variable_scope.variable_scope("decoder") as scope:
                # Prepare attention
                (attention_keys, attention_values, attention_score_fn,
                 attention_construct_fn) = (
                     attention_decoder_fn.prepare_attention(
                         attention_states, attention_option,
                         decoder_hidden_size))
                decoder_fn_train = attention_decoder_fn.attention_decoder_fn_train(
                    encoder_state=encoder_state,
                    attention_keys=attention_keys,
                    attention_values=attention_values,
                    attention_score_fn=attention_score_fn,
                    attention_construct_fn=attention_construct_fn)

                # setting up weights for computing the final output
                def create_output_fn():
                    def output_fn(x):
                        return layers.linear(x, num_decoder_symbols, scope=scope)

                    return output_fn

                output_fn = create_output_fn()

                # Train decoder
                decoder_cell = core_rnn_cell_impl.GRUCell(
                    decoder_hidden_size)
                (decoder_outputs_train, decoder_state_train,
                 _) = (seq2seq.dynamic_rnn_decoder(
                     cell=decoder_cell,
                     decoder_fn=decoder_fn_train,
                     inputs=decoder_inputs,
                     sequence_length=decoder_length,
                     time_major=True,
                     scope=scope))
                decoder_outputs_train = output_fn(decoder_outputs_train)
                # Setup variable reuse
                scope.reuse_variables()

                # Inference decoder
                decoder_fn_inference = (
                    attention_decoder_fn.attention_decoder_fn_inference(
                        output_fn=output_fn,
                        encoder_state=encoder_state,
                        attention_keys=attention_keys,
                        attention_values=attention_values,
                        attention_score_fn=attention_score_fn,
                        attention_construct_fn=attention_construct_fn,
                        embeddings=decoder_embeddings,
                        start_of_sequence_id=start_of_sequence_id,
                        end_of_sequence_id=end_of_sequence_id,
                        maximum_length=decoder_sequence_length - 1,
                        num_decoder_symbols=num_decoder_symbols,
                        dtype=dtypes.int32))
                (decoder_outputs_inference, decoder_state_inference,
                 _) = (seq2seq.dynamic_rnn_decoder(
                     cell=decoder_cell,
                     decoder_fn=decoder_fn_inference,
                     time_major=True,
                     scope=scope))

            # Run model
            variables.global_variables_initializer().run()
            (decoder_outputs_train_res, decoder_state_train_res) = sess.run(
                [decoder_outputs_train, decoder_state_train])
            (decoder_outputs_inference_res, decoder_state_inference_res) = sess.run(
                [decoder_outputs_inference, decoder_state_inference])

            # Assert outputs
            self.assertEqual(
                (decoder_sequence_length, batch_size, num_decoder_symbols),
                decoder_outputs_train_res.shape)
            self.assertEqual((batch_size, num_decoder_symbols),
                             decoder_outputs_inference_res.shape[1:3])
            self.assertEqual((batch_size, decoder_hidden_size),
                             decoder_state_train_res.shape)
            self.assertEqual((batch_size, decoder_hidden_size),
                             decoder_state_inference_res.shape)
            # The dynamic decoder might end earlier than `maximum_length`
            # under inference, so bound the inference outputs' time axis
            # (axis 0, time-major). The final state's axis 0 is the batch
            # axis, so checking it would be vacuous.
            self.assertGreaterEqual(decoder_sequence_length,
                                    decoder_outputs_inference_res.shape[0])
def __init__(self,
             num_symbols,
             num_embed_units,
             num_units,
             num_layers,
             beam_size,
             embed,
             learning_rate=0.5,
             remove_unk=False,
             learning_rate_decay_factor=0.95,
             max_gradient_norm=5.0,
             num_samples=512,
             max_length=8,
             use_lstm=False):
    """Build an attention seq2seq graph with greedy and beam-search decoders.

    The graph does the following:
      * Maps string tokens to ids through two checkpointed `MutableHashTable`s.
      * Encodes the post with a stacked RNN.
      * Trains a Luong-attention decoder with the loss returned by
        `output_projection_layer`.
      * Builds greedy and beam-search inference decoders that reuse the
        training decoder's variables.
      * Registers a serving `Exporter` for the beam-search outputs.

    Args:
        num_symbols: vocabulary size (embedding rows / decoder symbols).
        num_embed_units: embedding width; only used when `embed` is None.
        num_units: hidden size of every RNN layer.
        num_layers: number of stacked RNN layers (encoder and decoder).
        beam_size: beam width of the beam-search decoder.
        embed: optional pre-trained embedding used as the initializer of
            the 'embed' variable; None means a random
            `[num_symbols, num_embed_units]` table.
        learning_rate: initial SGD learning rate.
        remove_unk: forwarded to `attention_decoder_fn_beam_inference`.
        learning_rate_decay_factor: factor applied by `learning_rate_decay_op`.
        max_gradient_norm: threshold for global-norm gradient clipping.
        num_samples: forwarded to `output_projection_layer`.
        max_length: maximum decode length of the inference decoders.
        use_lstm: use LSTM cells instead of GRU cells.
    """
    self.posts = tf.placeholder(tf.string, (None, None), 'enc_inps')  # batch*len
    # NOTE(review): `(None)` is plain `None` (unknown rank), not `(None,)`.
    self.posts_length = tf.placeholder(tf.int32, (None), 'enc_lens')  # batch
    self.responses = tf.placeholder(tf.string, (None, None), 'dec_inps')  # batch*len
    self.responses_length = tf.placeholder(tf.int32, (None), 'dec_lens')  # batch

    # initialize the training process
    self.learning_rate = tf.Variable(float(learning_rate), trainable=False, dtype=tf.float32)
    self.learning_rate_decay_op = self.learning_rate.assign(
        self.learning_rate * learning_rate_decay_factor)
    self.global_step = tf.Variable(0, trainable=False)

    # string <-> id tables; checkpoint=True saves their contents with the model
    self.symbol2index = MutableHashTable(
        key_dtype=tf.string, value_dtype=tf.int64, default_value=UNK_ID,
        shared_name="in_table", name="in_table", checkpoint=True)
    self.index2symbol = MutableHashTable(
        key_dtype=tf.int64, value_dtype=tf.string, default_value='_UNK',
        shared_name="out_table", name="out_table", checkpoint=True)

    # build the vocab table (string to index)
    self.posts_input = self.symbol2index.lookup(self.posts)  # batch*len
    self.responses_target = self.symbol2index.lookup(self.responses)  # batch*len

    batch_size, decoder_len = tf.shape(self.responses)[0], tf.shape(self.responses)[1]
    # decoder input = GO_ID followed by the target with its last column dropped
    self.responses_input = tf.concat([
        tf.ones([batch_size, 1], dtype=tf.int64) * GO_ID,
        tf.split(self.responses_target, [decoder_len - 1, 1], 1)[0]
    ], 1)  # batch*len
    # reverse cumsum of a one-hot at (length - 1) yields 1.0 for positions < length
    self.decoder_mask = tf.reshape(
        tf.cumsum(tf.one_hot(self.responses_length - 1, decoder_len), reverse=True, axis=1),
        [-1, decoder_len])

    # build the embedding table (index to vector)
    if embed is None:
        # initialize the embedding randomly
        self.embed = tf.get_variable('embed', [num_symbols, num_embed_units], tf.float32)
    else:
        # initialize the embedding by pre-trained word vectors
        self.embed = tf.get_variable('embed', dtype=tf.float32, initializer=embed)

    self.encoder_input = tf.nn.embedding_lookup(self.embed, self.posts_input)  # batch*len*unit
    self.decoder_input = tf.nn.embedding_lookup(self.embed, self.responses_input)

    def _make_cell():
        # Build a fresh stack with a distinct cell object per layer.
        # `[cell] * num_layers` repeats ONE object, and reusing one stack for
        # both encoder and decoder reuses it across variable scopes. TF 1.1
        # rejects that, and TF 1.2+ silently shares the weights. Up to TF
        # 1.0.1 variables are named purely by scope, so this creates exactly
        # the same variables as before.
        cell_class = LSTMCell if use_lstm else GRUCell
        return MultiRNNCell([cell_class(num_units) for _ in range(num_layers)])

    encoder_cell = _make_cell()
    decoder_cell = _make_cell()

    # rnn encoder
    encoder_output, encoder_state = dynamic_rnn(
        encoder_cell, self.encoder_input, self.posts_length,
        dtype=tf.float32, scope="encoder")

    # get output projection function
    output_fn, sampled_sequence_loss = output_projection_layer(
        num_units, num_symbols, num_samples)

    # get attention function
    (attention_keys, attention_values, attention_score_fn,
     attention_construct_fn) = attention_decoder_fn.prepare_attention(
         encoder_output, 'luong', num_units)

    # training decoder: creates the 'decoder' variables reused by both
    # inference decoders below
    with tf.variable_scope('decoder'):
        decoder_fn_train = attention_decoder_fn.attention_decoder_fn_train(
            encoder_state, attention_keys, attention_values,
            attention_score_fn, attention_construct_fn)
        self.decoder_output, _, _ = dynamic_rnn_decoder(
            decoder_cell, decoder_fn_train, self.decoder_input,
            self.responses_length, scope="decoder_rnn")
        self.decoder_loss = sampled_sequence_loss(
            self.decoder_output, self.responses_target, self.decoder_mask)

    # greedy inference decoder
    with tf.variable_scope('decoder', reuse=True):
        decoder_fn_inference = attention_decoder_fn.attention_decoder_fn_inference(
            output_fn, encoder_state, attention_keys, attention_values,
            attention_score_fn, attention_construct_fn, self.embed,
            GO_ID, EOS_ID, max_length, num_symbols)
        self.decoder_distribution, _, _ = dynamic_rnn_decoder(
            decoder_cell, decoder_fn_inference, scope="decoder_rnn")
        # argmax over symbols 2.. and shift back by 2, so ids 0 and 1 are
        # never generated (for removing UNK)
        self.generation_index = tf.argmax(
            tf.split(self.decoder_distribution, [2, num_symbols - 2], 2)[1], 2) + 2
        self.generation = self.index2symbol.lookup(self.generation_index, name='generation')

    # beam-search inference decoder; results are returned in its context state
    with tf.variable_scope('decoder', reuse=True):
        decoder_fn_beam_inference = attention_decoder_fn_beam_inference(
            output_fn, encoder_state, attention_keys, attention_values,
            attention_score_fn, attention_construct_fn, self.embed,
            GO_ID, EOS_ID, max_length, num_symbols, beam_size, remove_unk)
        _, _, self.context_state = dynamic_rnn_decoder(
            decoder_cell, decoder_fn_beam_inference, scope="decoder_rnn")
        (_, beam_parents, beam_symbols, result_probs, result_parents,
         result_symbols) = self.context_state

        # Each stacked array is reshaped to [max_length + 1, -1, width] and
        # transposed so the -1 axis (presumably batch) comes first.
        self.beam_parents = tf.transpose(
            tf.reshape(beam_parents.stack(), [max_length + 1, -1, beam_size]),
            [1, 0, 2], name='beam_parents')
        self.beam_symbols = tf.transpose(
            tf.reshape(beam_symbols.stack(), [max_length + 1, -1, beam_size]),
            [1, 0, 2])
        self.beam_symbols = self.index2symbol.lookup(
            tf.cast(self.beam_symbols, tf.int64), name="beam_symbols")

        self.result_probs = tf.transpose(
            tf.reshape(result_probs.stack(), [max_length + 1, -1, beam_size * 2]),
            [1, 0, 2], name='result_probs')
        self.result_symbols = tf.transpose(
            tf.reshape(result_symbols.stack(), [max_length + 1, -1, beam_size * 2]),
            [1, 0, 2])
        self.result_parents = tf.transpose(
            tf.reshape(result_parents.stack(), [max_length + 1, -1, beam_size * 2]),
            [1, 0, 2], name='result_parents')
        self.result_symbols = self.index2symbol.lookup(
            tf.cast(self.result_symbols, tf.int64), name='result_symbols')

    self.params = tf.trainable_variables()

    # calculate the gradient of parameters
    opt = tf.train.GradientDescentOptimizer(self.learning_rate)
    gradients = tf.gradients(self.decoder_loss, self.params)
    clipped_gradients, self.gradient_norm = tf.clip_by_global_norm(
        gradients, max_gradient_norm)
    self.update = opt.apply_gradients(
        zip(clipped_gradients, self.params), global_step=self.global_step)

    self.saver = tf.train.Saver(
        write_version=tf.train.SaverDef.V2, max_to_keep=3,
        pad_step_number=True, keep_checkpoint_every_n_hours=1.0)

    # Exporter for serving
    self.model_exporter = exporter.Exporter(self.saver)
    inputs = {"enc_inps:0": self.posts, "enc_lens:0": self.posts_length}
    outputs = {
        "beam_symbols": self.beam_symbols,
        "beam_parents": self.beam_parents,
        "result_probs": self.result_probs,
        "result_symbols": self.result_symbols,
        "result_parents": self.result_parents
    }
    self.model_exporter.init(
        tf.get_default_graph().as_graph_def(),
        named_graph_signatures={
            "inputs": exporter.generic_signature(inputs),
            "outputs": exporter.generic_signature(outputs)
        })
def __init__(self,
             num_symbols,
             num_qwords,  # modify
             num_embed_units,
             num_units,
             num_layers,
             is_train,
             vocab=None,
             embed=None,
             question_data=True,
             learning_rate=0.5,
             learning_rate_decay_factor=0.95,
             max_gradient_norm=5.0,
             num_samples=512,
             max_length=30,
             use_lstm=False):
    """Build a keyword-aware attention seq2seq graph for training or inference.

    With `is_train` the graph holds the training decoder, its loss and an SGD
    update op. Otherwise it holds only the inference decoder and the
    generated token strings. Both modes create `self.saver` over all global
    variables.

    Args:
        num_symbols: vocabulary size.
        num_qwords: forwarded to `output_projection_layer`.
        num_embed_units: embedding width; only used when `embed` is None.
        num_units: hidden size of every RNN layer.
        num_layers: number of stacked RNN layers.
        is_train: build the training graph (True) or the inference graph.
        vocab: initial value of the 'symbols' variable when training
            (presumably the token strings, index == id).
        embed: optional pre-trained embedding initializer.
        question_data: forwarded to `output_projection_layer`.
        learning_rate: initial SGD learning rate (training only).
        learning_rate_decay_factor: factor applied by `learning_rate_decay_op`.
        max_gradient_norm: global-norm clipping threshold (training only).
        num_samples: forwarded to `output_projection_layer`.
        max_length: maximum decode length of the inference decoder.
        use_lstm: use LSTM cells instead of GRU cells.
    """
    self.posts = tf.placeholder(tf.string, shape=(None, None))  # batch*len
    # NOTE(review): `(None)` is plain `None` (unknown rank), not `(None,)`.
    self.posts_length = tf.placeholder(tf.int32, shape=(None))  # batch
    self.responses = tf.placeholder(tf.string, shape=(None, None))  # batch*len
    self.responses_length = tf.placeholder(tf.int32, shape=(None))  # batch
    self.keyword_tensor = tf.placeholder(tf.float32, shape=(None, 3, None))  # (batch * len) * 3 * numsymbol
    self.word_type = tf.placeholder(tf.int32, shape=(None))  # (batch * len)

    # build the vocab table (string to index)
    if is_train:
        self.symbols = tf.Variable(vocab, trainable=False, name="symbols")
    else:
        # Dummy contents, presumably replaced when a checkpoint is restored.
        # NOTE(review): unlike the training branch this variable is trainable,
        # so it ends up in `self.params` -- confirm that is intended.
        self.symbols = tf.Variable(np.array(['.'] * num_symbols), name="symbols")
    self.symbol2index = HashTable(
        KeyValueTensorInitializer(
            self.symbols,
            tf.Variable(np.arange(num_symbols, dtype=np.int32), False)),
        default_value=UNK_ID, name="symbol2index")

    self.posts_input = self.symbol2index.lookup(self.posts)  # batch*len
    self.responses_target = self.symbol2index.lookup(self.responses)  # batch*len

    batch_size, decoder_len = tf.shape(self.responses)[0], tf.shape(self.responses)[1]
    # delete the last column of responses_target and add GO_ID at the front of it
    self.responses_input = tf.concat([
        tf.ones([batch_size, 1], dtype=tf.int32) * GO_ID,
        tf.split(self.responses_target, [decoder_len - 1, 1], 1)[0]
    ], 1)  # batch*len
    # reverse cumsum of a one-hot at (length - 1) yields 1.0 for positions < length
    self.decoder_mask = tf.reshape(
        tf.cumsum(tf.one_hot(self.responses_length - 1, decoder_len), reverse=True, axis=1),
        [-1, decoder_len])  # batch * len

    # Single-argument print(...) forms print exactly what the old Python 2
    # statements printed, and also parse under Python 3.
    print("embedding...")
    # build the embedding table (index to vector)
    if embed is None:
        # initialize the embedding randomly
        self.embed = tf.get_variable('embed', [num_symbols, num_embed_units], tf.float32)
    else:
        print("%s %s %s" % (len(vocab), len(embed), len(embed[0])))
        print(embed)
        # initialize the embedding by pre-trained word vectors
        self.embed = tf.get_variable('embed', dtype=tf.float32, initializer=embed)

    self.encoder_input = tf.nn.embedding_lookup(self.embed, self.posts_input)  # batch*len*unit
    self.decoder_input = tf.nn.embedding_lookup(self.embed, self.responses_input)

    print("embedding finished")

    def _make_cell():
        # Build a fresh stack with a distinct cell object per layer.
        # `[cell] * num_layers` repeats ONE object, and reusing one stack for
        # both encoder and decoder reuses it across variable scopes. TF 1.1
        # rejects that, and TF 1.2+ silently shares the weights. Up to TF
        # 1.0.1 variables are named purely by scope, so this creates exactly
        # the same variables as before.
        cell_class = LSTMCell if use_lstm else GRUCell
        return MultiRNNCell([cell_class(num_units) for _ in range(num_layers)])

    encoder_cell = _make_cell()
    decoder_cell = _make_cell()

    # rnn encoder
    encoder_output, encoder_state = dynamic_rnn(
        encoder_cell, self.encoder_input, self.posts_length,
        dtype=tf.float32, scope="encoder")
    # get output projection function
    output_fn, sampled_sequence_loss = output_projection_layer(
        num_units, num_symbols, num_qwords, num_samples, question_data)

    print("encoder_output.shape: %s" % (encoder_output.get_shape(),))

    # get attention function
    (attention_keys, attention_values, attention_score_fn,
     attention_construct_fn) = attention_decoder_fn.prepare_attention(
         encoder_output, 'luong', num_units)

    # get decoding loop functions (both are built regardless of is_train)
    decoder_fn_train = attention_decoder_fn.attention_decoder_fn_train(
        encoder_state, attention_keys, attention_values,
        attention_score_fn, attention_construct_fn)
    decoder_fn_inference = attention_decoder_fn.attention_decoder_fn_inference(
        output_fn, self.keyword_tensor, encoder_state, attention_keys,
        attention_values, attention_score_fn, attention_construct_fn,
        self.embed, GO_ID, EOS_ID, max_length, num_symbols)

    if is_train:
        # rnn decoder
        self.decoder_output, _, _ = dynamic_rnn_decoder(
            decoder_cell, decoder_fn_train, self.decoder_input,
            self.responses_length, scope="decoder")
        # calculate the loss of decoder
        self.decoder_loss, self.log_perplexity = sampled_sequence_loss(
            self.decoder_output, self.responses_target, self.decoder_mask,
            self.keyword_tensor, self.word_type)

        # building graph finished and get all parameters
        self.params = tf.trainable_variables()

        for item in self.params:
            print("%s %s" % (item.name, item.get_shape()))

        # initialize the training process
        self.learning_rate = tf.Variable(float(learning_rate), trainable=False, dtype=tf.float32)
        self.learning_rate_decay_op = self.learning_rate.assign(
            self.learning_rate * learning_rate_decay_factor)

        self.global_step = tf.Variable(0, trainable=False)

        # calculate the gradient of parameters
        opt = tf.train.GradientDescentOptimizer(self.learning_rate)
        gradients = tf.gradients(self.decoder_loss, self.params)
        clipped_gradients, self.gradient_norm = tf.clip_by_global_norm(
            gradients, max_gradient_norm)
        self.update = opt.apply_gradients(
            zip(clipped_gradients, self.params), global_step=self.global_step)
    else:
        # rnn decoder
        self.decoder_distribution, _, _ = dynamic_rnn_decoder(
            decoder_cell, decoder_fn_inference, scope="decoder")
        print("self.decoder_distribution.shape():",self.decoder_distribution.get_shape())
        # debug: tf.Print logs the SUM of the distribution (the label says shape)
        self.decoder_distribution = tf.Print(
            self.decoder_distribution,
            ["distribution.shape()", tf.reduce_sum(self.decoder_distribution)])
        # generating the response: argmax over symbols 2.. shifted back by 2,
        # so ids 0 and 1 are never generated (for removing UNK)
        self.generation_index = tf.argmax(
            tf.split(self.decoder_distribution, [2, num_symbols - 2], 2)[1], 2) + 2
        self.generation = tf.nn.embedding_lookup(self.symbols, self.generation_index)

        self.params = tf.trainable_variables()

    self.saver = tf.train.Saver(
        tf.global_variables(), write_version=tf.train.SaverDef.V2,
        max_to_keep=3, pad_step_number=True, keep_checkpoint_every_n_hours=1.0)
def test_dynamic_rnn_decoder():
    """Smoke-run `seq2seq.dynamic_rnn_decoder` with train and inference decoder fns.

    Builds a GRU encoder plus one decoder used twice:
      * a training pass fed with constant decoder inputs;
      * an inference pass that reuses the same variables.
    Both decoder fns are wrapped by `_decoder_fn_with_context_state`
    (presumably so that each decoder also returns a context state -- TODO
    confirm against that helper). The function runs both passes, prints the
    output shapes and the argmax symbol ids, and asserts nothing.
    """
    with tf.Session() as sess:
        with tf.variable_scope(
                "root", initializer=tf.constant_initializer(0.5)) as varscope:
            # tiny, fixed problem sizes
            batch_size = 2
            encoder_embedding_size = 3
            decoder_embedding_size = 4
            encoder_hidden_size = 5
            decoder_hidden_size = encoder_hidden_size
            input_sequence_length = 6
            decoder_sequence_length = 7
            num_decoder_symbols = 20
            start_of_sequence_id = end_of_sequence_id = 1
            decoder_embeddings = tf.get_variable(
                "decoder_embeddings",
                [num_decoder_symbols, decoder_embedding_size],
                initializer=tf.random_normal_initializer(stddev=0.1))
            # constant time-major inputs: [time, batch, depth]
            inputs = tf.constant(0.5, shape=[
                input_sequence_length, batch_size, encoder_embedding_size
            ])
            decoder_inputs = tf.constant(0.4, shape=[
                decoder_sequence_length, batch_size, decoder_embedding_size
            ])
            decoder_length = tf.constant(decoder_sequence_length,
                                         dtype=dtypes.int32,
                                         shape=[
                                             batch_size,
                                         ])
            with tf.variable_scope("rnn") as scope:
                # setting up weights for computing the final output
                # NOTE: the lambda looks up `scope` when it is CALLED, not
                # here. Both calls happen after `scope` is rebound to the
                # "decoder" scope below. This presumably lets the inference
                # call reuse the weights that the train call created after
                # scope.reuse_variables(). Do not "fix" this by binding the
                # "rnn" scope early.
                output_fn = lambda x: layers.linear(
                    x, num_decoder_symbols, scope=scope)

                # Define model
                encoder_outputs, encoder_state = rnn.dynamic_rnn(
                    cell=core_rnn_cell_impl.GRUCell(encoder_hidden_size),
                    inputs=inputs,
                    dtype=dtypes.float32,
                    time_major=True,
                    scope=scope)

            with tf.variable_scope("decoder") as scope:
                # Train decoder
                decoder_cell = core_rnn_cell_impl.GRUCell(decoder_hidden_size)
                decoder_fn_train = _decoder_fn_with_context_state(
                    decoder_fn_lib.simple_decoder_fn_train(
                        encoder_state=encoder_state))
                (decoder_outputs_train, decoder_state_train,
                 decoder_context_state_train) = seq2seq.dynamic_rnn_decoder(
                     cell=decoder_cell,
                     decoder_fn=decoder_fn_train,
                     inputs=decoder_inputs,
                     sequence_length=decoder_length,
                     time_major=True,
                     scope=scope)
                # first output_fn call: creates the projection variables
                decoder_outputs_train = output_fn(decoder_outputs_train)

                # Setup variable reuse
                scope.reuse_variables()

                # Inference decoder: feeds back its own predictions for at
                # most decoder_sequence_length - 1 steps
                decoder_fn_inference = _decoder_fn_with_context_state(
                    decoder_fn_lib.simple_decoder_fn_inference(
                        output_fn=output_fn,
                        encoder_state=encoder_state,
                        embeddings=decoder_embeddings,
                        start_of_sequence_id=start_of_sequence_id,
                        end_of_sequence_id=end_of_sequence_id,
                        maximum_length=decoder_sequence_length - 1,
                        num_decoder_symbols=num_decoder_symbols,
                        dtype=dtypes.int32))
                (decoder_outputs_inference, decoder_state_inference,
                 decoder_context_state_inference) = (
                     seq2seq.dynamic_rnn_decoder(
                         cell=decoder_cell,
                         decoder_fn=decoder_fn_inference,
                         time_major=True,
                         scope=scope))

            # predicted symbol ids along the last (symbol) axis
            output_train = tf.argmax(decoder_outputs_train, axis=2)
            output_inference = tf.argmax(decoder_outputs_inference, axis=2)

            # Run model
            tf.global_variables_initializer().run()
            (decoder_outputs_train_res, decoder_state_train_res,
             decoder_context_state_train_res) = sess.run([
                 decoder_outputs_train, decoder_state_train,
                 decoder_context_state_train
             ])
            (decoder_outputs_inference_res, decoder_state_inference_res,
             decoder_context_state_inference_res) = sess.run([
                 decoder_outputs_inference, decoder_state_inference,
                 decoder_context_state_inference
             ])
            print np.shape(decoder_outputs_train_res)
            print np.shape(decoder_outputs_inference_res)

            output_train, output_inference = sess.run(
                [output_train, output_inference])
            print output_train
            print output_inference
def _build_graph(self):
    """Build the attentive QE tagging graph in a fresh `tf.Graph` (`self.graph`).

    Architecture:
      * The source is encoded by a bidirectional stacked LSTM.
      * A GRU decoder with Bahdanau attention over that encoding reads the
        target tokens.
      * Its outputs are concatenated with a bidirectional LSTM representation
        of the target.
      * The result is projected through two dense layers to
        `self.output_vocab_size` logits per target position.
      * Training uses masked sequence cross-entropy with Adam.

    Reads `self.random_seed`, `self.src_vocab_size`, `self.trg_vocab_size`,
    `self.output_vocab_size`, `self.use_ner_embeddings` and these
    `self.config` keys: 'embedding_size', 'encoder_hidden_size',
    'decoder_hidden_size', 'lstm_stack_size' and 'max_gradient_norm'.
    The placeholders and the main ops (predictions, cost,
    full_graph_optimizer, saver, merged summaries) are exposed as
    attributes.
    """
    # build the graph
    self.graph = tf.Graph()
    with self.graph.as_default():
        tf.set_random_seed(self.random_seed)

        # DATASET PLACEHOLDERS
        # (batch, time)
        source = tf.placeholder(tf.int32)
        source_mask = tf.placeholder(tf.float32)
        target = tf.placeholder(tf.int32)
        target_mask = tf.placeholder(tf.float32)
        output = tf.placeholder(tf.int32)
        output_mask = tf.placeholder(tf.float32)

        # TODO: add factored contexts (POS, NER, ETC...)

        # Despite its name, this value is used as the KEEP probability of
        # every DropoutWrapper and tf.nn.dropout below (feed 1.0 to disable).
        dropout_prob = tf.placeholder(tf.float32)

        def make_lstm_stack():
            # Return a NEW stack of dropout-wrapped peephole LSTMs. The fw and
            # bw directions must not share one cell object: since TF 1.1 that
            # either raises or silently ties their weights. Up to TF 1.0.1
            # variables are named by scope, so the variables are unchanged.
            return tf.contrib.rnn.MultiRNNCell(
                [tf.contrib.rnn.DropoutWrapper(
                    tf.contrib.rnn.LSTMCell(
                        self.config['encoder_hidden_size'],
                        use_peepholes=True,
                        state_is_tuple=True),
                    input_keep_prob=dropout_prob,
                    output_keep_prob=dropout_prob)
                 for _ in range(self.config['lstm_stack_size'])],
                state_is_tuple=True)

        with tf.name_scope('embeddings'):
            source_embeddings = tf.get_variable(
                "source_embeddings",
                [self.src_vocab_size, self.config['embedding_size']],
                trainable=True)
            # TODO: support factors (e.g. NER embeddings) for source and target inputs

            # default: just embed the tokens in the source context
            source_embed = tf.nn.embedding_lookup(source_embeddings, source)
            if self.use_ner_embeddings:
                pass  # TODO: concat factor embeddings onto source_embed
            else:
                # this is to fix shape inference bug in rnn.py -- see this issue:
                # https://github.com/tensorflow/tensorflow/issues/2938
                source_embed.set_shape([None, None, self.config['embedding_size']])

            # TODO: support target language factors (POS, NER, etc...)
            target_embeddings = tf.get_variable(
                "target_embeddings",
                [self.trg_vocab_size, self.config['embedding_size']])
            # target embeddings - these are the _inputs_ to the decoder
            target_embed = tf.nn.embedding_lookup(target_embeddings, target)
            target_embed.set_shape([None, None, self.config['embedding_size']])

        # Construct input representation that we'll put attention over
        with tf.name_scope('input_representation'):
            # use the source mask to get the sequence lengths
            source_sequence_length = tf.cast(tf.reduce_sum(source_mask, 1), tf.int64)

            # Bidir outputs are (output_fw, output_bw)
            bidir_outputs, bidir_state = tf.nn.bidirectional_dynamic_rnn(
                cell_fw=make_lstm_stack(),
                cell_bw=make_lstm_stack(),
                inputs=source_embed,
                sequence_length=source_sequence_length,
                dtype=tf.float32)

            l_to_r_states, r_to_l_states = bidir_state

            # batch-major; fw and bw outputs concatenated on the feature axis
            attention_states = tf.concat(bidir_outputs, 2)

            # Decoder initial state: the top layer's h of the bw and fw final
            # LSTMStateTuples ([-1] = last layer, [1] = h), concatenated
            # (2 * encoder_hidden_size) and projected to decoder_hidden_size.
            init_state_transformation = tf.get_variable(
                'decoder_init_transform',
                (self.config['encoder_hidden_size'] * 2,
                 self.config['decoder_hidden_size']))
            initialization_state = tf.matmul(
                tf.concat([r_to_l_states[-1][1], l_to_r_states[-1][1]], 1),
                init_state_transformation)

        with tf.name_scope('target_representation'):
            # bidirectional target representation
            target_lengths = tf.cast(tf.reduce_sum(target_mask, axis=1), dtype=tf.int32)
            target_bidir_outputs, _ = tf.nn.bidirectional_dynamic_rnn(
                cell_fw=make_lstm_stack(),
                cell_bw=make_lstm_stack(),
                inputs=target_embed,
                sequence_length=target_lengths,
                dtype=tf.float32,
                scope='target_bidir_rnn')
            target_representation = tf.concat(target_bidir_outputs, 2)

        # Now construct the decoder
        decoder_hidden_size = self.config['decoder_hidden_size']
        attention_option = "bahdanau"  # can be "luong"

        with variable_scope.variable_scope("decoder") as scope:
            # Prepare attention
            (attention_keys, attention_values, attention_score_fn,
             attention_construct_fn) = attention_decoder_fn.prepare_attention(
                 attention_states, attention_option, decoder_hidden_size)

            decoder_fn_train = attention_decoder_fn.attention_decoder_fn_train(
                encoder_state=initialization_state,
                attention_keys=attention_keys,
                attention_values=attention_values,
                attention_score_fn=attention_score_fn,
                attention_construct_fn=attention_construct_fn)

            # Unlike a "normal" seq2seq model, the decoder's output vocabulary
            # (QE symbols) differs from its input vocabulary (target tokens),
            # so the output projection is built by hand:
            # [decoder out ; target repr] -> elu -> dropout -> dense(512)
            #   -> elu -> dropout -> dense(output_vocab_size)
            intermediate_dim = 512
            output_transformation_1 = tf.Variable(
                tf.random_normal([
                    self.config['decoder_hidden_size'] +
                    self.config['encoder_hidden_size'] * 2,
                    intermediate_dim
                ]),
                name='output_transformation_1')
            output_biases_1 = tf.Variable(
                tf.zeros([intermediate_dim]), name='output_biases_1')
            output_transformation_2 = tf.Variable(
                tf.random_normal([intermediate_dim, self.output_vocab_size]),
                name='output_transformation_2')
            output_biases_2 = tf.Variable(
                tf.zeros([self.output_vocab_size]), name='output_biases_2')

            # Train decoder
            decoder_cell = core_rnn_cell_impl.GRUCell(decoder_hidden_size)
            (decoder_outputs_train, _, _) = seq2seq.dynamic_rnn_decoder(
                cell=decoder_cell,
                decoder_fn=decoder_fn_train,
                inputs=target_embed,
                sequence_length=target_lengths,
                time_major=False,
                scope=scope)

            # TODO: for attentive QE we don't need separate train and inference
            # decoders; the train decoder output is used at prediction time too.

            # concat with target lm representation
            decoder_outputs_train = tf.concat(
                [decoder_outputs_train, target_representation], 2)
            decoder_outputs_train = tf.nn.elu(decoder_outputs_train)
            decoder_outputs_train = tf.nn.dropout(decoder_outputs_train, keep_prob=dropout_prob)

            # flatten (batch, time, features) to 2-D for the dense layers
            output_shape = tf.shape(decoder_outputs_train)
            decoder_outputs_train = tf.matmul(
                tf.reshape(decoder_outputs_train,
                           [output_shape[0] * output_shape[1], -1]),
                output_transformation_1)
            decoder_outputs_train += output_biases_1
            decoder_outputs_train = tf.nn.elu(decoder_outputs_train)
            decoder_outputs_train = tf.nn.dropout(decoder_outputs_train, keep_prob=dropout_prob)

            # one more linear layer, then restore (batch, time, vocab)
            decoder_outputs_train = tf.matmul(decoder_outputs_train, output_transformation_2)
            decoder_outputs_train += output_biases_2
            decoder_outputs_train = tf.reshape(
                decoder_outputs_train, [output_shape[0], output_shape[1], -1])

        with tf.name_scope('predictions'):
            prediction_logits = decoder_outputs_train
            tf.summary.histogram('prediction_logits', prediction_logits)
            predictions = tf.nn.softmax(prediction_logits)

        with tf.name_scope('xent'):
            # Note: set output and output_mask shape because they're needed here:
            # https://github.com/tensorflow/tensorflow/blob/r1.0/tensorflow/contrib/seq2seq/python/ops/loss.py#L65-L70
            output.set_shape([None, None])
            output_mask.set_shape([None, None])

            costs = tf.contrib.seq2seq.sequence_loss(
                logits=decoder_outputs_train,
                targets=output,
                weights=output_mask,
                average_across_timesteps=True)
            cost = tf.reduce_mean(costs)
            tf.summary.scalar('minibatch_cost', cost)

        # expose placeholders and ops on the class
        self.source = source
        self.source_mask = source_mask
        self.target = target
        self.target_mask = target_mask
        self.output = output
        self.output_mask = output_mask
        self.predictions = predictions
        self.cost = cost
        self.dropout_prob = dropout_prob

        # TODO: expose embeddings so that they can be visualized?

        optimizer = tf.train.AdamOptimizer()
        with tf.name_scope('train'):
            gradients = optimizer.compute_gradients(cost, tf.trainable_variables())
            if self.config['max_gradient_norm'] is not None:
                gradients, variables = zip(*gradients)
                clipped_gradients, _ = clip_ops.clip_by_global_norm(
                    gradients, self.config['max_gradient_norm'])
                gradients = list(zip(clipped_gradients, variables))

            for gradient, variable in gradients:
                # compute_gradients yields None for variables the cost does
                # not depend on; a histogram of None would fail
                if gradient is None:
                    continue
                if isinstance(gradient, ops.IndexedSlices):
                    grad_values = gradient.values
                else:
                    grad_values = gradient
                tf.summary.histogram(variable.name, variable)
                tf.summary.histogram(variable.name + '/gradients', grad_values)
                tf.summary.histogram(variable.name + '/gradient_norm',
                                     clip_ops.global_norm([grad_values]))

            self.full_graph_optimizer = optimizer.apply_gradients(gradients)

        self.saver = tf.train.Saver()
        self.merged = tf.summary.merge_all()

    logger.info('Finished building model graph')
def __init__(self, sess, config, api, log_dir, forward, scope=None):
    """Build a conditional-VAE dialog response model.

    The graph contains:
      * a context encoder over utterances plus a 2-way floor one-hot;
      * a recognition network and a prior network over the latent z;
      * a bag-of-words head;
      * when `config.use_hcf` is set, a dialog-act head;
      * an RNN decoder with an output projection.

    Args:
        sess: session stored on `self` and passed to `self.optimize`.
        config: hyper-parameter object (sizes, keep probabilities,
            `use_hcf`, `sent_type`, `cell_type`, `num_layer`, ...).
        api: corpus object providing `vocab`, `rev_vocab`, `topic_vocab`
            and `dialog_act_vocab`.
        log_dir: when not None the KL weight is annealed as
            `global_t / config.full_kl_step` (capped at 1.0), otherwise it
            is 1.0; also forwarded to `self.optimize`.
        forward: True builds the inference decoder only (no losses or
            optimizer). False builds the training decoder, the losses, and
            calls `self.optimize`.
        scope: stored as `self.scope`.
    """
    self.vocab = api.vocab
    self.rev_vocab = api.rev_vocab
    self.vocab_size = len(self.vocab)
    self.topic_vocab = api.topic_vocab
    self.topic_vocab_size = len(self.topic_vocab)
    self.da_vocab = api.dialog_act_vocab
    self.da_vocab_size = len(self.da_vocab)
    self.sess = sess
    self.scope = scope
    self.max_utt_len = config.max_utt_len
    self.go_id = self.rev_vocab["<s>"]
    self.eos_id = self.rev_vocab["</s>"]
    self.context_cell_size = config.cxt_cell_size
    self.sent_cell_size = config.sent_cell_size
    self.dec_cell_size = config.dec_cell_size

    with tf.name_scope("io"):
        # all dialog context and known attributes
        self.input_contexts = tf.placeholder(dtype=tf.int32, shape=(None, None, self.max_utt_len), name="dialog_context")
        self.floors = tf.placeholder(dtype=tf.int32, shape=(None, None), name="floor")
        self.context_lens = tf.placeholder(dtype=tf.int32, shape=(None,), name="context_lens")
        self.topics = tf.placeholder(dtype=tf.int32, shape=(None,), name="topics")
        self.my_profile = tf.placeholder(dtype=tf.float32, shape=(None, 4), name="my_profile")
        self.ot_profile = tf.placeholder(dtype=tf.float32, shape=(None, 4), name="ot_profile")

        # target response given the dialog context
        self.output_tokens = tf.placeholder(dtype=tf.int32, shape=(None, None), name="output_token")
        self.output_lens = tf.placeholder(dtype=tf.int32, shape=(None,), name="output_lens")
        self.output_das = tf.placeholder(dtype=tf.int32, shape=(None,), name="output_dialog_acts")

        # optimization related variables
        self.learning_rate = tf.Variable(float(config.init_lr), trainable=False, name="learning_rate")
        self.learning_rate_decay_op = self.learning_rate.assign(tf.multiply(self.learning_rate, config.lr_decay))
        self.global_t = tf.placeholder(dtype=tf.int32, name="global_t")
        self.use_prior = tf.placeholder(dtype=tf.bool, name="use_prior")

    max_dialog_len = array_ops.shape(self.input_contexts)[1]
    max_out_len = array_ops.shape(self.output_tokens)[1]
    batch_size = array_ops.shape(self.input_contexts)[0]

    with variable_scope.variable_scope("topicEmbedding"):
        t_embedding = tf.get_variable("embedding", [self.topic_vocab_size, config.topic_embed_size], dtype=tf.float32)
        topic_embedding = embedding_ops.embedding_lookup(t_embedding, self.topics)

    if config.use_hcf:
        with variable_scope.variable_scope("dialogActEmbedding"):
            d_embedding = tf.get_variable("embedding", [self.da_vocab_size, config.da_embed_size], dtype=tf.float32)
            da_embedding = embedding_ops.embedding_lookup(d_embedding, self.output_das)

    with variable_scope.variable_scope("wordEmbedding"):
        self.embedding = tf.get_variable("embedding", [self.vocab_size, config.embed_size], dtype=tf.float32)
        # zero out row 0 of the word embedding (presumably the padding id)
        embedding_mask = tf.constant([0 if i == 0 else 1 for i in range(self.vocab_size)],
                                     dtype=tf.float32, shape=[self.vocab_size, 1])
        embedding = self.embedding * embedding_mask

        # embed every utterance independently: (batch * dialog_len, max_utt_len, embed)
        input_embedding = embedding_ops.embedding_lookup(embedding, tf.reshape(self.input_contexts, [-1]))
        input_embedding = tf.reshape(input_embedding, [-1, self.max_utt_len, config.embed_size])
        output_embedding = embedding_ops.embedding_lookup(embedding, self.output_tokens)

        if config.sent_type == "bow":
            input_embedding, sent_size = get_bow(input_embedding)
            output_embedding, _ = get_bow(output_embedding)
        elif config.sent_type == "rnn":
            sent_cell = self.get_rnncell("gru", self.sent_cell_size, config.keep_prob, 1)
            input_embedding, sent_size = get_rnn_encode(input_embedding, sent_cell, scope="sent_rnn")
            output_embedding, _ = get_rnn_encode(output_embedding, sent_cell, self.output_lens,
                                                 scope="sent_rnn", reuse=True)
        elif config.sent_type == "bi_rnn":
            fwd_sent_cell = self.get_rnncell("gru", self.sent_cell_size, keep_prob=1.0, num_layer=1)
            bwd_sent_cell = self.get_rnncell("gru", self.sent_cell_size, keep_prob=1.0, num_layer=1)
            input_embedding, sent_size = get_bi_rnn_encode(input_embedding, fwd_sent_cell, bwd_sent_cell,
                                                           scope="sent_bi_rnn")
            output_embedding, _ = get_bi_rnn_encode(output_embedding, fwd_sent_cell, bwd_sent_cell,
                                                    self.output_lens, scope="sent_bi_rnn", reuse=True)
        else:
            raise ValueError("Unknown sent_type. Must be one of [bow, rnn, bi_rnn]")

        # reshape input into dialogs
        input_embedding = tf.reshape(input_embedding, [-1, max_dialog_len, sent_size])
        if config.keep_prob < 1.0:
            input_embedding = tf.nn.dropout(input_embedding, config.keep_prob)

        # convert floors into 1 hot
        floor_one_hot = tf.one_hot(tf.reshape(self.floors, [-1]), depth=2, dtype=tf.float32)
        floor_one_hot = tf.reshape(floor_one_hot, [-1, max_dialog_len, 2])

        joint_embedding = tf.concat([input_embedding, floor_one_hot], 2, "joint_embedding")

    with variable_scope.variable_scope("contextRNN"):
        enc_cell = self.get_rnncell(config.cell_type, self.context_cell_size, keep_prob=1.0,
                                    num_layer=config.num_layer)
        # with sequence_length given, enc_last_state is the state at each
        # dialog's true last utterance
        _, enc_last_state = tf.nn.dynamic_rnn(
            enc_cell,
            joint_embedding,
            dtype=tf.float32,
            sequence_length=self.context_lens)

        if config.num_layer > 1:
            enc_last_state = tf.concat(enc_last_state, 1)

    # combine with other attributes
    if config.use_hcf:
        attribute_embedding = da_embedding
        attribute_fc1 = layers.fully_connected(attribute_embedding, 30, activation_fn=tf.tanh, scope="attribute_fc1")

    cond_list = [topic_embedding, self.my_profile, self.ot_profile, enc_last_state]
    cond_embedding = tf.concat(cond_list, 1)

    with variable_scope.variable_scope("recognitionNetwork"):
        if config.use_hcf:
            recog_input = tf.concat([cond_embedding, output_embedding, attribute_fc1], 1)
        else:
            recog_input = tf.concat([cond_embedding, output_embedding], 1)
        self.recog_mulogvar = recog_mulogvar = layers.fully_connected(
            recog_input, config.latent_size * 2, activation_fn=None, scope="muvar")
        recog_mu, recog_logvar = tf.split(recog_mulogvar, 2, axis=1)

    with variable_scope.variable_scope("priorNetwork"):
        # P(XYZ)=P(Z|X)P(X)P(Y|X,Z)
        prior_fc1 = layers.fully_connected(cond_embedding, np.maximum(config.latent_size * 2, 100),
                                           activation_fn=tf.tanh, scope="fc1")
        prior_mulogvar = layers.fully_connected(prior_fc1, config.latent_size * 2, activation_fn=None,
                                                scope="muvar")
        prior_mu, prior_logvar = tf.split(prior_mulogvar, 2, axis=1)

        # sample z from the prior (use_prior) or from the recognition posterior
        latent_sample = tf.cond(self.use_prior,
                                lambda: sample_gaussian(prior_mu, prior_logvar),
                                lambda: sample_gaussian(recog_mu, recog_logvar))

    with variable_scope.variable_scope("generationNetwork"):
        gen_inputs = tf.concat([cond_embedding, latent_sample], 1)

        # BOW loss
        bow_fc1 = layers.fully_connected(gen_inputs, 400, activation_fn=tf.tanh, scope="bow_fc1")
        if config.keep_prob < 1.0:
            bow_fc1 = tf.nn.dropout(bow_fc1, config.keep_prob)
        self.bow_logits = layers.fully_connected(bow_fc1, self.vocab_size, activation_fn=None,
                                                 scope="bow_project")

        # Y loss
        if config.use_hcf:
            meta_fc1 = layers.fully_connected(gen_inputs, 400, activation_fn=tf.tanh, scope="meta_fc1")
            if config.keep_prob < 1.0:
                meta_fc1 = tf.nn.dropout(meta_fc1, config.keep_prob)
            self.da_logits = layers.fully_connected(meta_fc1, self.da_vocab_size, scope="da_project")
            da_prob = tf.nn.softmax(self.da_logits)
            # expected dialog-act embedding under the predicted distribution
            pred_attribute_embedding = tf.matmul(da_prob, d_embedding)
            if forward:
                selected_attribute_embedding = pred_attribute_embedding
            else:
                selected_attribute_embedding = attribute_embedding
            dec_inputs = tf.concat([gen_inputs, selected_attribute_embedding], 1)
        else:
            self.da_logits = tf.zeros((batch_size, self.da_vocab_size))
            dec_inputs = gen_inputs
            # No dialog-act context on this path. Both decoder branches below
            # read this name, which was previously left undefined here
            # (NameError whenever use_hcf is False). Assumes decoder_fn_lib's
            # context fns skip the concat for a None context_vector -- confirm.
            selected_attribute_embedding = None

        # Decoder initial state: one projection of dec_inputs per layer
        if config.num_layer > 1:
            dec_init_state = [layers.fully_connected(dec_inputs, self.dec_cell_size, activation_fn=None,
                                                     scope="init_state-%d" % i)
                              for i in range(config.num_layer)]
            dec_init_state = tuple(dec_init_state)
        else:
            dec_init_state = layers.fully_connected(dec_inputs, self.dec_cell_size, activation_fn=None,
                                                    scope="init_state")

    with variable_scope.variable_scope("decoder"):
        dec_cell = self.get_rnncell(config.cell_type, self.dec_cell_size, config.keep_prob, config.num_layer)
        dec_cell = rnn_cell.OutputProjectionWrapper(dec_cell, self.vocab_size)

        if forward:
            loop_func = decoder_fn_lib.context_decoder_fn_inference(
                None, dec_init_state, embedding,
                start_of_sequence_id=self.go_id,
                end_of_sequence_id=self.eos_id,
                maximum_length=self.max_utt_len,
                num_decoder_symbols=self.vocab_size,
                context_vector=selected_attribute_embedding)
            dec_input_embedding = None
            dec_seq_lens = None
        else:
            loop_func = decoder_fn_lib.context_decoder_fn_train(dec_init_state, selected_attribute_embedding)
            # teacher forcing: feed every output token except the last
            dec_input_embedding = embedding_ops.embedding_lookup(embedding, self.output_tokens)
            dec_input_embedding = dec_input_embedding[:, 0:-1, :]
            dec_seq_lens = self.output_lens - 1

            if config.keep_prob < 1.0:
                dec_input_embedding = tf.nn.dropout(dec_input_embedding, config.keep_prob)

            # apply word dropping. Set dropped word to 0
            if config.dec_keep_prob < 1.0:
                keep_mask = tf.less_equal(tf.random_uniform((batch_size, max_out_len - 1), minval=0.0, maxval=1.0),
                                          config.dec_keep_prob)
                keep_mask = tf.expand_dims(tf.to_float(keep_mask), 2)
                dec_input_embedding = dec_input_embedding * keep_mask
                dec_input_embedding = tf.reshape(dec_input_embedding, [-1, max_out_len - 1, config.embed_size])

        dec_outs, _, final_context_state = dynamic_rnn_decoder(dec_cell, loop_func,
                                                               inputs=dec_input_embedding,
                                                               sequence_length=dec_seq_lens)
        if final_context_state is not None:
            final_context_state = final_context_state[:, 0:array_ops.shape(dec_outs)[1]]
            mask = tf.to_int32(tf.sign(tf.reduce_max(dec_outs, axis=2)))
            self.dec_out_words = tf.multiply(tf.reverse(final_context_state, axis=[1]), mask)
        else:
            self.dec_out_words = tf.argmax(dec_outs, 2)

    if not forward:
        with variable_scope.variable_scope("loss"):
            labels = self.output_tokens[:, 1:]
            # token id 0 contributes no loss
            label_mask = tf.to_float(tf.sign(labels))

            rc_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=dec_outs, labels=labels)
            rc_loss = tf.reduce_sum(rc_loss * label_mask, reduction_indices=1)
            self.avg_rc_loss = tf.reduce_mean(rc_loss)
            # used only for perplexity calculation. Not used for optimization
            self.rc_ppl = tf.exp(tf.reduce_sum(rc_loss) / tf.reduce_sum(label_mask))

            # BOW loss: score every response token against one shared
            # distribution (the response treated as an n-trial multinomial)
            tile_bow_logits = tf.tile(tf.expand_dims(self.bow_logits, 1), [1, max_out_len - 1, 1])
            bow_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=tile_bow_logits,
                                                                      labels=labels) * label_mask
            bow_loss = tf.reduce_sum(bow_loss, reduction_indices=1)
            self.avg_bow_loss = tf.reduce_mean(bow_loss)

            # reconstruct the meta info about X
            if config.use_hcf:
                da_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.da_logits,
                                                                         labels=self.output_das)
                self.avg_da_loss = tf.reduce_mean(da_loss)
            else:
                self.avg_da_loss = 0.0

            kld = gaussian_kld(recog_mu, recog_logvar, prior_mu, prior_logvar)
            self.avg_kld = tf.reduce_mean(kld)
            if log_dir is not None:
                kl_weights = tf.minimum(tf.to_float(self.global_t) / config.full_kl_step, 1.0)
            else:
                kl_weights = tf.constant(1.0)

            self.kl_w = kl_weights
            self.elbo = self.avg_rc_loss + kl_weights * self.avg_kld
            aug_elbo = self.avg_bow_loss + self.avg_da_loss + self.elbo

            tf.summary.scalar("da_loss", self.avg_da_loss)
            tf.summary.scalar("rc_loss", self.avg_rc_loss)
            tf.summary.scalar("elbo", self.elbo)
            tf.summary.scalar("kld", self.avg_kld)
            tf.summary.scalar("bow_loss", self.avg_bow_loss)

            self.summary_op = tf.summary.merge_all()

            self.log_p_z = norm_log_liklihood(latent_sample, prior_mu, prior_logvar)
            self.log_q_z_xy = norm_log_liklihood(latent_sample, recog_mu, recog_logvar)
            self.est_marginal = tf.reduce_mean(rc_loss + bow_loss - self.log_p_z + self.log_q_z_xy)

        self.optimize(sess, config, aug_elbo, log_dir)

    self.saver = tf.train.Saver(tf.global_variables(), write_version=tf.train.SaverDef.V2)