def __init__(self, num_symbols, num_embed_units, num_units, num_layers,
             vocab=None, embed=None, name_scope=None, learning_rate=0.001,
             learning_rate_decay_factor=0.95, max_gradient_norm=5,
             num_samples=512, max_length=30):
    """Build the full TF1 graph for a GRU seq2seq model with Bahdanau attention.

    Constructs placeholders, vocab lookup table, embedding, encoder, a training
    decoder (per-example loss weighted by ``self.weight``) and an inference
    decoder, plus the Adam training op and two checkpoint savers.

    Args:
        num_symbols: vocabulary size.
        num_embed_units: word-embedding dimensionality.
        num_units: GRU hidden size (encoder and decoder).
        num_layers: number of stacked GRU layers.
        vocab: list/array of vocabulary strings used to build the string->index table.
        embed: optional pre-trained embedding matrix; random init when None.
        name_scope: substring used to select this model's variables for the
            optimizer and savers.
            NOTE(review): the default ``None`` would raise ``TypeError`` in the
            ``name_scope in k.name`` filters below — callers apparently always
            pass a string; confirm before relying on the default.
        learning_rate / learning_rate_decay_factor: initial LR and decay multiplier.
        max_gradient_norm: global-norm gradient clipping threshold.
        num_samples: sampled-softmax sample count for the output projection.
        max_length: maximum decoding length at inference time.
    """
    # --- Placeholders (fed per batch) ---
    self.posts = tf.placeholder(tf.string, shape=[None, None])  # batch * len
    self.posts_length = tf.placeholder(tf.int32, shape=[None])  # batch
    self.responses = tf.placeholder(tf.string, shape=[None, None])  # batch * len
    self.responses_length = tf.placeholder(tf.int32, shape=[None])  # batch
    self.weight = tf.placeholder(tf.float32, shape=[None])  # batch, per-example loss weight

    # --- Vocab table (string to index) ---
    self.symbols = tf.Variable(vocab, trainable=False, name="symbols")
    self.symbol2index = HashTable(KeyValueTensorInitializer(
        self.symbols,
        # np.arange replaces the equivalent list-comprehension construction
        tf.Variable(np.arange(num_symbols, dtype=np.int32), False)),
        default_value=UNK_ID,
        name="symbol2index")

    # --- Embedding table (index to vector) ---
    if embed is None:
        # initialize the embedding randomly
        self.embed = tf.get_variable('embed',
                                     [num_symbols, num_embed_units],
                                     tf.float32)
    else:
        # initialize the embedding from pre-trained word vectors
        self.embed = tf.get_variable('embed',
                                     dtype=tf.float32,
                                     initializer=embed)

    self.posts_input = self.symbol2index.lookup(self.posts)  # batch * utter_len
    self.encoder_input = tf.nn.embedding_lookup(
        self.embed, self.posts_input)  # batch * utter_len * embed_unit

    self.responses_target = self.symbol2index.lookup(self.responses)  # batch, len
    batch_size, decoder_len = tf.shape(self.responses)[0], tf.shape(
        self.responses)[1]
    # Teacher-forcing input: prepend GO_ID and drop the last target token.
    self.responses_input = tf.concat([
        tf.ones([batch_size, 1], dtype=tf.int32) * GO_ID,
        tf.split(self.responses_target, [decoder_len - 1, 1], 1)[0]
    ], 1)  # batch, len
    # 1.0 for positions up to each response's length, 0.0 after (via reversed cumsum).
    self.decoder_mask = tf.reshape(
        tf.cumsum(tf.one_hot(self.responses_length - 1, decoder_len),
                  reverse=True, axis=1),
        [-1, decoder_len])  # batch, len
    self.decoder_input = tf.nn.embedding_lookup(self.embed,
                                                self.responses_input)

    # Construct multi-layer GRU cells for encoder and decoder.
    cell_enc = MultiRNNCell([GRUCell(num_units) for _ in range(num_layers)])
    cell_dec = MultiRNNCell([GRUCell(num_units) for _ in range(num_layers)])

    # Encode the post sequence.
    encoder_output, encoder_state = tf.nn.dynamic_rnn(cell_enc,
                                                      self.encoder_input,
                                                      self.posts_length,
                                                      dtype=tf.float32,
                                                      scope="encoder")

    output_fn, sampled_sequence_loss = output_projection_layer(
        num_units, num_symbols, num_samples)

    attention_keys, attention_values, attention_score_fn, attention_construct_fn \
        = my_attention_decoder_fn.prepare_attention(encoder_output, 'bahdanau', num_units)

    # Decode the response sequence (training, teacher forcing).
    with variable_scope.variable_scope('decoder'):
        decoder_fn_train = my_attention_decoder_fn.attention_decoder_fn_train(
            encoder_state, attention_keys, attention_values,
            attention_score_fn, attention_construct_fn)
        self.decoder_output, _, _ = my_seq2seq.dynamic_rnn_decoder(
            cell_dec, decoder_fn_train, self.decoder_input,
            self.responses_length, scope='decoder_rnn')
        self.decoder_loss = my_loss.sequence_loss(
            self.decoder_output, self.responses_target, self.decoder_mask,
            softmax_loss_function=sampled_sequence_loss)
        # Per-example loss scaled by the fed weight (e.g. for RL/adversarial training).
        self.weighted_decoder_loss = self.decoder_loss * self.weight

    attention_keys_infer, attention_values_infer, attention_score_fn_infer, attention_construct_fn_infer \
        = my_attention_decoder_fn.prepare_attention(encoder_output, 'bahdanau', num_units, reuse=True)

    # Decode the response sequence (inference, greedy with attention).
    with variable_scope.variable_scope('decoder', reuse=True):
        decoder_fn_inference = my_attention_decoder_fn.attention_decoder_fn_inference(
            output_fn, encoder_state, attention_keys_infer,
            attention_values_infer, attention_score_fn_infer,
            attention_construct_fn_infer, self.embed, GO_ID, EOS_ID,
            max_length, num_symbols)
        self.decoder_distribution, _, _ = my_seq2seq.dynamic_rnn_decoder(
            cell_dec, decoder_fn_inference, scope='decoder_rnn')
        # Drop the first two symbol ids (PAD/UNK region) from the argmax,
        # then shift indices back by 2 — prevents generating UNK.
        self.generation_index = tf.argmax(
            tf.split(self.decoder_distribution, [2, num_symbols - 2], 2)[1],
            2) + 2
        self.generation = tf.nn.embedding_lookup(self.symbols,
                                                 self.generation_index)

    # Only optimize variables belonging to this model's name scope.
    self.params = [k for k in tf.trainable_variables() if name_scope in k.name]

    # --- Training machinery ---
    self.learning_rate = tf.Variable(float(learning_rate),
                                     trainable=False,
                                     dtype=tf.float32)
    self.learning_rate_decay_op = self.learning_rate.assign(
        self.learning_rate * learning_rate_decay_factor)
    self.global_step = tf.Variable(0, trainable=False)
    self.adv_global_step = tf.Variable(0, trainable=False)

    # Weighted cost drives the update; unweighted cost kept for monitoring.
    self.cost = tf.reduce_mean(self.weighted_decoder_loss)
    self.unweighted_cost = tf.reduce_mean(self.decoder_loss)
    opt = tf.train.AdamOptimizer(self.learning_rate)
    gradients = tf.gradients(self.cost, self.params)
    clipped_gradients, self.gradient_norm = tf.clip_by_global_norm(
        gradients, max_gradient_norm)
    self.update = opt.apply_gradients(zip(clipped_gradients, self.params),
                                      global_step=self.global_step)

    all_variables = [k for k in tf.global_variables() if name_scope in k.name]
    # Two savers over the same variable set: one for regular training
    # checkpoints, one for the adversarial phase.
    self.saver = tf.train.Saver(all_variables,
                                write_version=tf.train.SaverDef.V2,
                                max_to_keep=5,
                                pad_step_number=True,
                                keep_checkpoint_every_n_hours=1.0)
    self.adv_saver = tf.train.Saver(all_variables,
                                    write_version=tf.train.SaverDef.V2,
                                    max_to_keep=5,
                                    pad_step_number=True,
                                    keep_checkpoint_every_n_hours=1.0)
def __init__(self, num_symbols, num_embed_units, num_units, num_layers,
             is_train, vocab=None, embed=None, learning_rate=0.1,
             learning_rate_decay_factor=0.95, max_gradient_norm=5.0,
             num_samples=512, max_length=30, use_lstm=True):
    """Build a multi-turn encoder-decoder with knowledge-triple graph attention.

    Four consecutive posts are processed turn by turn: each turn attends over
    an attention-pooled embedding of its (head, relation, tail) triples plus
    the previous turn's RNN outputs, and the final decoder generates the
    response. In training mode the sum of the per-turn losses is minimized
    with clipped Momentum SGD; in inference mode a greedy decoder emits tokens.

    Fixes vs. the previous revision:
      * ``MultiRNNCell([cell] * num_layers)`` reused ONE cell object for every
        layer (shared weights, and an error in TF >= 1.1); each layer now gets
        its own cell instance, matching the sibling model in this file.
      * length placeholders used ``shape=(None)`` — which is just ``None``
        (fully unconstrained) — instead of the intended rank-1 ``(None,)``.
      * removed duplicated recomputation of ``responses_target`` and
        ``batch_size, decoder_len``.
    """
    # --- Placeholders: four posts, their triples/masks/lengths, and the response ---
    self.posts_1 = tf.placeholder(tf.string, shape=(None, None))
    self.posts_2 = tf.placeholder(tf.string, shape=(None, None))
    self.posts_3 = tf.placeholder(tf.string, shape=(None, None))
    self.posts_4 = tf.placeholder(tf.string, shape=(None, None))
    # Triples per turn: [batch, turn_len?, triple_num?, 3] of (head, relation, tail)
    # strings — TODO confirm the middle two axes against the feeder.
    self.entity_1 = tf.placeholder(tf.string, shape=(None, None, None, 3))
    self.entity_2 = tf.placeholder(tf.string, shape=(None, None, None, 3))
    self.entity_3 = tf.placeholder(tf.string, shape=(None, None, None, 3))
    self.entity_4 = tf.placeholder(tf.string, shape=(None, None, None, 3))
    self.entity_mask_1 = tf.placeholder(tf.float32, shape=(None, None, None))
    self.entity_mask_2 = tf.placeholder(tf.float32, shape=(None, None, None))
    self.entity_mask_3 = tf.placeholder(tf.float32, shape=(None, None, None))
    self.entity_mask_4 = tf.placeholder(tf.float32, shape=(None, None, None))
    # (None,) — rank-1 vectors; the previous "(None)" meant an unconstrained shape.
    self.posts_length_1 = tf.placeholder(tf.int32, shape=(None,))
    self.posts_length_2 = tf.placeholder(tf.int32, shape=(None,))
    self.posts_length_3 = tf.placeholder(tf.int32, shape=(None,))
    self.posts_length_4 = tf.placeholder(tf.int32, shape=(None,))
    self.responses = tf.placeholder(tf.string, shape=(None, None))
    self.responses_length = tf.placeholder(tf.int32, shape=(None,))

    self.epoch = tf.Variable(0, trainable=False, name='epoch')
    self.epoch_add_op = self.epoch.assign(self.epoch + 1)

    # At inference time the real vocab is restored from a checkpoint, so a
    # dummy (trainable) placeholder vocab is used here.
    if is_train:
        self.symbols = tf.Variable(vocab, trainable=False, name="symbols")
    else:
        self.symbols = tf.Variable(np.array(['.'] * num_symbols),
                                   name="symbols")

    self.symbol2index = HashTable(KeyValueTensorInitializer(
        self.symbols,
        tf.Variable(np.arange(num_symbols, dtype=np.int32), False)),
        default_value=UNK_ID,
        name="symbol2index")

    self.posts_input_1 = self.symbol2index.lookup(self.posts_1)
    # Each intermediate turn is both a prediction target and (shifted) an
    # input; despite the "_embed" name these are index tensors.
    self.posts_2_target = self.posts_2_embed = self.symbol2index.lookup(self.posts_2)
    self.posts_3_target = self.posts_3_embed = self.symbol2index.lookup(self.posts_3)
    self.posts_4_target = self.posts_4_embed = self.symbol2index.lookup(self.posts_4)
    self.responses_target = self.symbol2index.lookup(self.responses)

    batch_size, decoder_len = tf.shape(self.posts_1)[0], tf.shape(
        self.responses)[1]

    # Teacher-forcing inputs: prepend GO_ID and drop the last token.
    self.posts_input_2 = tf.concat([
        tf.ones([batch_size, 1], dtype=tf.int32) * GO_ID,
        tf.split(self.posts_2_embed, [tf.shape(self.posts_2)[1] - 1, 1], 1)[0]
    ], 1)
    self.posts_input_3 = tf.concat([
        tf.ones([batch_size, 1], dtype=tf.int32) * GO_ID,
        tf.split(self.posts_3_embed, [tf.shape(self.posts_3)[1] - 1, 1], 1)[0]
    ], 1)
    self.posts_input_4 = tf.concat([
        tf.ones([batch_size, 1], dtype=tf.int32) * GO_ID,
        tf.split(self.posts_4_embed, [tf.shape(self.posts_4)[1] - 1, 1], 1)[0]
    ], 1)
    self.responses_input = tf.concat([
        tf.ones([batch_size, 1], dtype=tf.int32) * GO_ID,
        tf.split(self.responses_target, [decoder_len - 1, 1], 1)[0]
    ], 1)

    # Valid-position masks (1.0 up to each sequence's length) via reversed cumsum.
    self.encoder_2_mask = tf.reshape(
        tf.cumsum(tf.one_hot(self.posts_length_2 - 1,
                             tf.shape(self.posts_2)[1]),
                  reverse=True, axis=1),
        [-1, tf.shape(self.posts_2)[1]])
    self.encoder_3_mask = tf.reshape(
        tf.cumsum(tf.one_hot(self.posts_length_3 - 1,
                             tf.shape(self.posts_3)[1]),
                  reverse=True, axis=1),
        [-1, tf.shape(self.posts_3)[1]])
    self.encoder_4_mask = tf.reshape(
        tf.cumsum(tf.one_hot(self.posts_length_4 - 1,
                             tf.shape(self.posts_4)[1]),
                  reverse=True, axis=1),
        [-1, tf.shape(self.posts_4)[1]])
    self.decoder_mask = tf.reshape(
        tf.cumsum(tf.one_hot(self.responses_length - 1, decoder_len),
                  reverse=True, axis=1),
        [-1, decoder_len])

    # --- Embedding table (random or pre-trained) ---
    if embed is None:
        self.embed = tf.get_variable('embed',
                                     [num_symbols, num_embed_units],
                                     tf.float32)
    else:
        self.embed = tf.get_variable('embed',
                                     dtype=tf.float32,
                                     initializer=embed)

    self.encoder_input_1 = tf.nn.embedding_lookup(self.embed, self.posts_input_1)
    self.encoder_input_2 = tf.nn.embedding_lookup(self.embed, self.posts_input_2)
    self.encoder_input_3 = tf.nn.embedding_lookup(self.embed, self.posts_input_3)
    self.encoder_input_4 = tf.nn.embedding_lookup(self.embed, self.posts_input_4)
    self.decoder_input = tf.nn.embedding_lookup(self.embed, self.responses_input)

    # Embed each (head, relation, tail) triple and flatten the 3 components
    # into one 3*embed_units vector per triple.
    entity_embedding_1 = tf.reshape(
        tf.nn.embedding_lookup(self.embed,
                               self.symbol2index.lookup(self.entity_1)),
        [batch_size, tf.shape(self.entity_1)[1], tf.shape(self.entity_1)[2],
         3 * num_embed_units])
    entity_embedding_2 = tf.reshape(
        tf.nn.embedding_lookup(self.embed,
                               self.symbol2index.lookup(self.entity_2)),
        [batch_size, tf.shape(self.entity_2)[1], tf.shape(self.entity_2)[2],
         3 * num_embed_units])
    entity_embedding_3 = tf.reshape(
        tf.nn.embedding_lookup(self.embed,
                               self.symbol2index.lookup(self.entity_3)),
        [batch_size, tf.shape(self.entity_3)[1], tf.shape(self.entity_3)[2],
         3 * num_embed_units])
    entity_embedding_4 = tf.reshape(
        tf.nn.embedding_lookup(self.embed,
                               self.symbol2index.lookup(self.entity_4)),
        [batch_size, tf.shape(self.entity_4)[1], tf.shape(self.entity_4)[2],
         3 * num_embed_units])

    head_1, relation_1, tail_1 = tf.split(entity_embedding_1,
                                          [num_embed_units] * 3, axis=3)
    head_2, relation_2, tail_2 = tf.split(entity_embedding_2,
                                          [num_embed_units] * 3, axis=3)
    head_3, relation_3, tail_3 = tf.split(entity_embedding_3,
                                          [num_embed_units] * 3, axis=3)
    head_4, relation_4, tail_4 = tf.split(entity_embedding_4,
                                          [num_embed_units] * 3, axis=3)

    # --- Graph attention: relation-scored, mask-weighted pooling of triples.
    # The first scope creates the dense weights; turns 2-4 reuse them.
    with tf.variable_scope('graph_attention'):
        # [batch_size, max_reponse_length, max_triple_num, 2*embed_units]
        head_tail_1 = tf.concat([head_1, tail_1], axis=3)
        # [batch_size, max_reponse_length, max_triple_num, embed_units]
        head_tail_transformed_1 = tf.layers.dense(head_tail_1,
                                                  num_embed_units,
                                                  activation=tf.tanh,
                                                  name='head_tail_transform')
        # [batch_size, max_reponse_length, max_triple_num, embed_units]
        relation_transformed_1 = tf.layers.dense(relation_1,
                                                 num_embed_units,
                                                 name='relation_transform')
        # [batch_size, max_reponse_length, max_triple_num]
        e_weight_1 = tf.reduce_sum(relation_transformed_1 *
                                   head_tail_transformed_1, axis=3)
        alpha_weight_1 = tf.nn.softmax(e_weight_1)
        # [batch_size, max_reponse_length, embed_units*2] — attention-weighted,
        # mask-gated sum over the triples.
        graph_embed_1 = tf.reduce_sum(
            tf.expand_dims(alpha_weight_1, 3) *
            (tf.expand_dims(self.entity_mask_1, 3) * head_tail_1),
            axis=2)
    with tf.variable_scope('graph_attention', reuse=True):
        head_tail_2 = tf.concat([head_2, tail_2], axis=3)
        head_tail_transformed_2 = tf.layers.dense(head_tail_2,
                                                  num_embed_units,
                                                  activation=tf.tanh,
                                                  name='head_tail_transform')
        relation_transformed_2 = tf.layers.dense(relation_2,
                                                 num_embed_units,
                                                 name='relation_transform')
        e_weight_2 = tf.reduce_sum(relation_transformed_2 *
                                   head_tail_transformed_2, axis=3)
        alpha_weight_2 = tf.nn.softmax(e_weight_2)
        graph_embed_2 = tf.reduce_sum(
            tf.expand_dims(alpha_weight_2, 3) *
            (tf.expand_dims(self.entity_mask_2, 3) * head_tail_2),
            axis=2)
    with tf.variable_scope('graph_attention', reuse=True):
        head_tail_3 = tf.concat([head_3, tail_3], axis=3)
        head_tail_transformed_3 = tf.layers.dense(head_tail_3,
                                                  num_embed_units,
                                                  activation=tf.tanh,
                                                  name='head_tail_transform')
        relation_transformed_3 = tf.layers.dense(relation_3,
                                                 num_embed_units,
                                                 name='relation_transform')
        e_weight_3 = tf.reduce_sum(relation_transformed_3 *
                                   head_tail_transformed_3, axis=3)
        alpha_weight_3 = tf.nn.softmax(e_weight_3)
        graph_embed_3 = tf.reduce_sum(
            tf.expand_dims(alpha_weight_3, 3) *
            (tf.expand_dims(self.entity_mask_3, 3) * head_tail_3),
            axis=2)
    with tf.variable_scope('graph_attention', reuse=True):
        head_tail_4 = tf.concat([head_4, tail_4], axis=3)
        head_tail_transformed_4 = tf.layers.dense(head_tail_4,
                                                  num_embed_units,
                                                  activation=tf.tanh,
                                                  name='head_tail_transform')
        relation_transformed_4 = tf.layers.dense(relation_4,
                                                 num_embed_units,
                                                 name='relation_transform')
        e_weight_4 = tf.reduce_sum(relation_transformed_4 *
                                   head_tail_transformed_4, axis=3)
        alpha_weight_4 = tf.nn.softmax(e_weight_4)
        graph_embed_4 = tf.reduce_sum(
            tf.expand_dims(alpha_weight_4, 3) *
            (tf.expand_dims(self.entity_mask_4, 3) * head_tail_4),
            axis=2)

    # One distinct cell per layer — "[cell] * num_layers" shared a single
    # instance (and its weights) across every layer.
    if use_lstm:
        cell = MultiRNNCell([LSTMCell(num_units) for _ in range(num_layers)])
    else:
        cell = MultiRNNCell([GRUCell(num_units) for _ in range(num_layers)])

    output_fn, sampled_sequence_loss = output_projection_layer(
        num_units, num_symbols, num_samples)

    # Turn 1: plain encoding of the first post.
    encoder_output_1, encoder_state_1 = dynamic_rnn(cell,
                                                    self.encoder_input_1,
                                                    self.posts_length_1,
                                                    dtype=tf.float32,
                                                    scope="encoder")

    # Turn 1 -> 2: decode post 2 attending over graph 1 + encoder 1 outputs.
    attention_keys_1, attention_values_1, attention_score_fn_1, attention_construct_fn_1 \
        = attention_decoder_fn.prepare_attention(graph_embed_1, encoder_output_1, 'luong', num_units)
    decoder_fn_train_1 = attention_decoder_fn.attention_decoder_fn_train(
        encoder_state_1, attention_keys_1, attention_values_1,
        attention_score_fn_1, attention_construct_fn_1,
        max_length=tf.reduce_max(self.posts_length_2))
    encoder_output_2, encoder_state_2, alignments_ta_2 = dynamic_rnn_decoder(
        cell, decoder_fn_train_1, self.encoder_input_2, self.posts_length_2,
        scope="decoder")
    self.alignments_2 = tf.transpose(alignments_ta_2.stack(), perm=[1, 0, 2])
    self.decoder_loss_2 = sampled_sequence_loss(encoder_output_2,
                                                self.posts_2_target,
                                                self.encoder_2_mask)

    # Subsequent turns reuse the attention/decoder variables created above.
    with variable_scope.variable_scope('', reuse=True):
        # Turn 2 -> 3.
        attention_keys_2, attention_values_2, attention_score_fn_2, attention_construct_fn_2 \
            = attention_decoder_fn.prepare_attention(graph_embed_2, encoder_output_2, 'luong', num_units)
        decoder_fn_train_2 = attention_decoder_fn.attention_decoder_fn_train(
            encoder_state_2, attention_keys_2, attention_values_2,
            attention_score_fn_2, attention_construct_fn_2,
            max_length=tf.reduce_max(self.posts_length_3))
        encoder_output_3, encoder_state_3, alignments_ta_3 = dynamic_rnn_decoder(
            cell, decoder_fn_train_2, self.encoder_input_3,
            self.posts_length_3, scope="decoder")
        self.alignments_3 = tf.transpose(alignments_ta_3.stack(),
                                         perm=[1, 0, 2])
        self.decoder_loss_3 = sampled_sequence_loss(encoder_output_3,
                                                    self.posts_3_target,
                                                    self.encoder_3_mask)

        # Turn 3 -> 4.
        attention_keys_3, attention_values_3, attention_score_fn_3, attention_construct_fn_3 \
            = attention_decoder_fn.prepare_attention(graph_embed_3, encoder_output_3, 'luong', num_units)
        decoder_fn_train_3 = attention_decoder_fn.attention_decoder_fn_train(
            encoder_state_3, attention_keys_3, attention_values_3,
            attention_score_fn_3, attention_construct_fn_3,
            max_length=tf.reduce_max(self.posts_length_4))
        encoder_output_4, encoder_state_4, alignments_ta_4 = dynamic_rnn_decoder(
            cell, decoder_fn_train_3, self.encoder_input_4,
            self.posts_length_4, scope="decoder")
        self.alignments_4 = tf.transpose(alignments_ta_4.stack(),
                                         perm=[1, 0, 2])
        self.decoder_loss_4 = sampled_sequence_loss(encoder_output_4,
                                                    self.posts_4_target,
                                                    self.encoder_4_mask)

        # Attention context for the final response decoder.
        attention_keys, attention_values, attention_score_fn, attention_construct_fn \
            = attention_decoder_fn.prepare_attention(graph_embed_4, encoder_output_4, 'luong', num_units)

    if is_train:
        # Final response decoder (teacher forcing), reusing decoder variables.
        with variable_scope.variable_scope('', reuse=True):
            decoder_fn_train = attention_decoder_fn.attention_decoder_fn_train(
                encoder_state_4, attention_keys, attention_values,
                attention_score_fn, attention_construct_fn,
                max_length=tf.reduce_max(self.responses_length))
            self.decoder_output, _, alignments_ta = dynamic_rnn_decoder(
                cell, decoder_fn_train, self.decoder_input,
                self.responses_length, scope="decoder")
            self.alignments = tf.transpose(alignments_ta.stack(),
                                           perm=[1, 0, 2])
            self.decoder_loss = sampled_sequence_loss(self.decoder_output,
                                                      self.responses_target,
                                                      self.decoder_mask)

        self.params = tf.trainable_variables()

        self.learning_rate = tf.Variable(float(learning_rate),
                                         trainable=False,
                                         dtype=tf.float32)
        self.learning_rate_decay_op = self.learning_rate.assign(
            self.learning_rate * learning_rate_decay_factor)
        self.global_step = tf.Variable(0, trainable=False)

        # opt = tf.train.GradientDescentOptimizer(self.learning_rate)
        opt = tf.train.MomentumOptimizer(self.learning_rate, 0.9)
        # All four per-turn losses are optimized jointly.
        gradients = tf.gradients(
            self.decoder_loss + self.decoder_loss_2 + self.decoder_loss_3 +
            self.decoder_loss_4, self.params)
        clipped_gradients, self.gradient_norm = tf.clip_by_global_norm(
            gradients, max_gradient_norm)
        self.update = opt.apply_gradients(zip(clipped_gradients, self.params),
                                          global_step=self.global_step)
    else:
        # Greedy inference decoder.
        with variable_scope.variable_scope('', reuse=True):
            decoder_fn_inference = attention_decoder_fn.attention_decoder_fn_inference(
                output_fn, encoder_state_4, attention_keys, attention_values,
                attention_score_fn, attention_construct_fn, self.embed,
                GO_ID, EOS_ID, max_length, num_symbols)
            self.decoder_distribution, _, alignments_ta = dynamic_rnn_decoder(
                cell, decoder_fn_inference, scope="decoder")
            output_len = tf.shape(self.decoder_distribution)[1]
            self.alignments = tf.transpose(
                alignments_ta.gather(tf.range(output_len)), [1, 0, 2])
            # Drop the first two ids (PAD/UNK region) before argmax, then
            # shift back by 2 so UNK is never generated.
            self.generation_index = tf.argmax(
                tf.split(self.decoder_distribution,
                         [2, num_symbols - 2], 2)[1], 2) + 2
            self.generation = tf.nn.embedding_lookup(self.symbols,
                                                     self.generation_index,
                                                     name="generation")

    self.params = tf.trainable_variables()
    self.saver = tf.train.Saver(tf.global_variables(),
                                write_version=tf.train.SaverDef.V2,
                                max_to_keep=10,
                                pad_step_number=True,
                                keep_checkpoint_every_n_hours=1.0)
class Model(BaseModel):
    """Topic-conditioned hierarchical-latent CVAE dialogue model (TF1 graph).

    Encodes a dialog context and topic, samples a two-level latent variable
    from a posterior (training) or prior (generation) network, and decodes the
    response. Also provides feed-dict construction and train/valid/test loops.

    Fixes vs. the previous revision (Python 2 remnants):
      * ``print report`` (py2 print statement) -> ``print(report)``.
      * periodic-logging cadence used py2 integer division (``num_batch / N``);
        under py3 that is a float modulo with a different cadence — now ``//``
        (guarded with ``max(1, ...)`` against tiny feeds).
    """

    def __init__(self, sess, config, api, log_dir, forward, scope=None):
        """Build the graph.

        Args:
            sess: TF session used by the optimizer setup.
            config: hyper-parameter namespace (cell sizes, embed sizes, lr, ...).
            api: corpus object providing vocab / topic vocab / dialog-act vocab.
            log_dir: summary directory; also switches KL annealing on when set.
            forward: True builds the inference (generation) decoder path.
            scope: optional variable-scope name stored for callers.
        """
        self.vocab = api.vocab
        self.rev_vocab = api.rev_vocab
        self.vocab_size = len(self.vocab)
        self.topic_vocab = api.topic_vocab
        self.topic_vocab_size = len(self.topic_vocab)
        self.da_vocab = api.dialog_act_vocab
        self.da_vocab_size = len(self.da_vocab)
        self.sess = sess
        self.scope = scope
        self.pad_id = self.rev_vocab["<pad>"]
        self.sos_id = self.rev_vocab["<s>"]
        self.eos_id = self.rev_vocab["</s>"]
        self.unk_id = self.rev_vocab["<unk>"]
        self.context_cell_size = config.cxt_cell_size
        self.sent_cell_size = config.sent_cell_size
        self.dec_cell_size = config.dec_cell_size
        self.latent_size = config.latent_size

        with tf.name_scope("io"):
            # All tokens are fed as strings and mapped to ids in-graph.
            self.input_contexts = tf.placeholder(dtype=tf.string,
                                                 shape=(None, None, None),
                                                 name="dialog_context")
            self.context_lens = tf.placeholder(dtype=tf.int32,
                                               shape=(None, ),
                                               name="context_lens")
            self.topics = tf.placeholder(dtype=tf.int32,
                                         shape=(None, ),
                                         name="topics")
            self.output_tokens = tf.placeholder(dtype=tf.string,
                                                shape=(None, None, None),
                                                name="output_token")
            self.output_lens = tf.placeholder(dtype=tf.int32,
                                              shape=(None, None),
                                              name="output_lens")
            self.learning_rate = tf.Variable(float(config.init_lr),
                                             trainable=False,
                                             name="learning_rate")
            self.learning_rate_decay_op = self.learning_rate.assign(
                tf.multiply(self.learning_rate, config.lr_decay))
            self.global_t = tf.placeholder(dtype=tf.int32, name="global_t")
            self.use_prior = tf.placeholder(dtype=tf.bool, name="use_prior")

        batch_size = tf.shape(self.input_contexts)[0]
        max_dialog_len = tf.shape(self.input_contexts)[1]
        max_out_len = tf.shape(self.output_tokens)[2]

        with tf.variable_scope("tokenization"):
            self.symbols = tf.Variable(self.vocab,
                                       trainable=False,
                                       name="symbols")
            self.symbol2index = HashTable(KeyValueTensorInitializer(
                self.symbols,
                tf.Variable(np.arange(self.vocab_size, dtype=np.int32),
                            False)),
                default_value=self.unk_id,
                name="symbol2index")
            self.contexts = self.symbol2index.lookup(self.input_contexts)
            self.responses_target = self.symbol2index.lookup(
                self.output_tokens)

        with tf.variable_scope("topic_embedding"):
            t_embedding = tf.get_variable(
                "embedding", [self.topic_vocab_size, config.topic_embed_size],
                dtype=tf.float32)
            topic_embedding = tf.nn.embedding_lookup(
                t_embedding, self.topics)  # [batch_size, topic_embed_size]

        with tf.variable_scope("word_embedding"):
            self.embedding = tf.get_variable(
                "embedding", [self.vocab_size, config.embed_size],
                dtype=tf.float32)
            # Zero out row 0 (pad) so pad tokens contribute nothing.
            embedding_mask = tf.constant(
                [0 if i == 0 else 1 for i in range(self.vocab_size)],
                dtype=tf.float32,
                shape=[self.vocab_size, 1])
            embedding = self.embedding * embedding_mask
            input_embedding = tf.nn.embedding_lookup(
                embedding, tf.reshape(self.contexts, [-1]))
            input_embedding = tf.reshape(
                input_embedding,
                [batch_size * max_dialog_len, -1, config.embed_size])
            output_embedding = tf.nn.embedding_lookup(
                embedding, tf.reshape(self.responses_target, [-1]))
            output_embedding = tf.reshape(
                output_embedding,
                [batch_size * max_dialog_len, -1, config.embed_size])

        with tf.variable_scope("uttrance_encoder"):
            # Sentence-level encoding of every utterance (inputs and outputs
            # share the same encoder weights via reuse=True).
            if config.sent_type == "rnn":
                sent_cell = self.create_rnn_cell(self.sent_cell_size)
                input_embedding, sent_size = get_rnn_encode(input_embedding,
                                                            sent_cell,
                                                            scope="sent_rnn")
                output_embedding, _ = get_rnn_encode(output_embedding,
                                                     sent_cell,
                                                     tf.reshape(
                                                         self.output_lens,
                                                         [-1]),
                                                     scope="sent_rnn",
                                                     reuse=True)
            elif config.sent_type == "bi_rnn":
                fwd_sent_cell = self.create_rnn_cell(self.sent_cell_size)
                bwd_sent_cell = self.create_rnn_cell(self.sent_cell_size)
                input_embedding, sent_size = get_bi_rnn_encode(
                    input_embedding,
                    fwd_sent_cell,
                    bwd_sent_cell,
                    scope="sent_bi_rnn")
                output_embedding, _ = get_bi_rnn_encode(output_embedding,
                                                        fwd_sent_cell,
                                                        bwd_sent_cell,
                                                        tf.reshape(
                                                            self.output_lens,
                                                            [-1]),
                                                        scope="sent_bi_rnn",
                                                        reuse=True)
            else:
                raise ValueError(
                    "Unknown sent_type. Must be one of [rnn, bi_rnn]")

            input_embedding = tf.reshape(
                input_embedding, [batch_size, max_dialog_len, sent_size])
            if config.keep_prob < 1.0:
                input_embedding = tf.nn.dropout(input_embedding,
                                                config.keep_prob)
            output_embedding = tf.reshape(
                output_embedding, [batch_size, max_dialog_len, sent_size])

        with tf.variable_scope("context_encoder"):
            enc_cell = self.create_rnn_cell(self.context_cell_size)
            cxt_outputs, _ = tf.nn.dynamic_rnn(
                enc_cell,
                input_embedding,
                dtype=tf.float32,
                sequence_length=self.context_lens)

        # Broadcast the topic embedding over every dialog step and join it
        # with the per-step context encoding.
        # [batch_size, max_dialog_len, context_cell_size + topic_embed_size]
        tile_topic_embedding = tf.reshape(
            tf.tile(topic_embedding, [1, max_dialog_len]),
            [batch_size, max_dialog_len, config.topic_embed_size])
        cond_embedding = tf.concat([tile_topic_embedding, cxt_outputs], -1)

        with tf.variable_scope("posterior_network"):
            # Posterior q(z | context, response) sees the gold response.
            recog_input = tf.concat([cond_embedding, output_embedding], -1)
            post_sample, recog_mu_1, recog_logvar_1, recog_mu_2, recog_logvar_2 = self.hierarchical_inference_net(
                recog_input)

        with tf.variable_scope("prior_network"):
            # Prior p(z | context) only sees the conditioning signal.
            prior_input = cond_embedding
            prior_sample, prior_mu_1, prior_logvar_1, prior_mu_2, prior_logvar_2 = self.hierarchical_inference_net(
                prior_input)

        latent_sample = tf.cond(self.use_prior,
                                lambda: prior_sample,
                                lambda: post_sample)

        with tf.variable_scope("decoder"):
            dec_inputs = tf.concat([cond_embedding, latent_sample], -1)
            dec_inputs_dim = config.latent_size + config.topic_embed_size + self.context_cell_size
            dec_inputs = tf.reshape(
                dec_inputs, [batch_size * max_dialog_len, dec_inputs_dim])
            dec_init_state = tf.contrib.layers.fully_connected(
                dec_inputs,
                self.dec_cell_size,
                activation_fn=None,
                scope="init_state")
            dec_cell = self.create_rnn_cell(self.dec_cell_size)
            output_fn, sampled_sequence_loss = output_projection_layer(
                self.dec_cell_size, self.vocab_size)
            decoder_fn_train = decoder_fn.simple_decoder_fn_train(
                dec_init_state, dec_inputs)
            decoder_fn_inference = decoder_fn.simple_decoder_fn_inference(
                output_fn, dec_init_state, dec_inputs, embedding, self.sos_id,
                self.eos_id, max_out_len * 2, self.vocab_size)

            if forward:
                dec_outs, _, final_context_state = dynamic_rnn_decoder(
                    dec_cell, decoder_fn_inference, scope="decoder")
            else:
                dec_input_embedding = tf.nn.embedding_lookup(
                    embedding, tf.reshape(self.responses_target, [-1]))
                dec_input_embedding = tf.reshape(
                    dec_input_embedding,
                    [batch_size * max_dialog_len, -1, config.embed_size])
                # Shift: feed all but the last token, predict all but the first.
                dec_input_embedding = dec_input_embedding[:, 0:-1, :]
                dec_seq_lens = tf.reshape(self.output_lens, [-1]) - 1

                if config.dec_keep_prob < 1.0:
                    # Word dropout on decoder inputs (entire embeddings zeroed).
                    keep_mask = tf.less_equal(
                        tf.random_uniform(
                            (batch_size * max_dialog_len, max_out_len - 1),
                            minval=0.0,
                            maxval=1.0), config.dec_keep_prob)
                    keep_mask = tf.expand_dims(tf.to_float(keep_mask), 2)
                    dec_input_embedding = dec_input_embedding * keep_mask
                    dec_input_embedding = tf.reshape(
                        dec_input_embedding,
                        [-1, max_out_len - 1, config.embed_size])

                dec_outs, _, final_context_state = dynamic_rnn_decoder(
                    dec_cell,
                    decoder_fn_train,
                    dec_input_embedding,
                    dec_seq_lens,
                    scope="decoder")

                reshape_target = tf.reshape(
                    self.responses_target, [batch_size * max_dialog_len, -1])
                labels = reshape_target[:, 1:]
                # Non-pad positions (pad id is 0) count toward the loss.
                label_mask = tf.to_float(tf.sign(labels))
                local_loss = sampled_sequence_loss(dec_outs, labels,
                                                   label_mask)

            if final_context_state is not None:
                # Inference path: the decoder fn tracks emitted ids; trim to
                # the produced length and mask out all-zero steps.
                final_context_state = final_context_state[:, 0:tf.
                                                          shape(dec_outs)[1]]
                mask = tf.to_int32(tf.sign(tf.reduce_max(dec_outs, axis=2)))
                dec_out_words = tf.multiply(
                    tf.reverse(final_context_state, axis=[1]), mask)
            else:
                dec_out_words = tf.argmax(dec_outs, 2)

            # Keep only the last dialog step's response tokens.
            self.dec_out_words = tf.reshape(
                dec_out_words, [batch_size, max_dialog_len, -1])[:, -1, :]

        if not forward:
            with tf.variable_scope("loss"):
                self.avg_rc_loss = tf.reduce_mean(local_loss)
                # Summed (not averaged) rc loss, used for perplexity.
                self.rc_ppl = tf.reduce_sum(local_loss)
                self.total_word = tf.reduce_sum(label_mask)

                # NOTE(review): the mu/logvar tensors come from half-size
                # latent groups yet are reshaped with width config.latent_size,
                # and level-2 posterior is paired with level-1 prior (and vice
                # versa) — looks intentional for the hierarchical KL but
                # confirm against the original paper/implementation.
                new_recog_mu_2 = tf.reshape(recog_mu_2,
                                            [-1, config.latent_size])
                new_recog_logvar_2 = tf.reshape(recog_logvar_2,
                                                [-1, config.latent_size])
                new_prior_mu_1 = tf.reshape(prior_mu_1,
                                            [-1, config.latent_size])
                new_prior_logvar_1 = tf.reshape(prior_logvar_1,
                                                [-1, config.latent_size])
                new_recog_mu_1 = tf.reshape(recog_mu_1,
                                            [-1, config.latent_size])
                new_recog_logvar_1 = tf.reshape(recog_logvar_1,
                                                [-1, config.latent_size])
                new_prior_mu_2 = tf.reshape(prior_mu_2,
                                            [-1, config.latent_size])
                new_prior_logvar_2 = tf.reshape(prior_logvar_2,
                                                [-1, config.latent_size])
                kld_1 = gaussian_kld(new_recog_mu_2, new_recog_logvar_2,
                                     new_prior_mu_1, new_prior_logvar_1)
                kld_2 = gaussian_kld(new_recog_mu_1, new_recog_logvar_1,
                                     new_prior_mu_2, new_prior_logvar_2)
                kld = kld_1 + kld_2
                self.avg_kld = tf.reduce_mean(kld)

                # Linear KL annealing during training; full weight otherwise.
                if log_dir is not None:
                    self.kl_w = tf.minimum(
                        tf.to_float(self.global_t) / config.full_kl_step, 1.0)
                else:
                    self.kl_w = tf.constant(1.0)

                aug_elbo = self.elbo = self.avg_rc_loss + self.kl_w * self.avg_kld

                tf.summary.scalar("rc_loss", self.avg_rc_loss)
                tf.summary.scalar("elbo", self.elbo)
                tf.summary.scalar("kld", self.avg_kld)
                self.summary_op = tf.summary.merge_all()

            self.optimize(sess, config, aug_elbo, log_dir)

        self.saver = tf.train.Saver(tf.global_variables(),
                                    write_version=tf.train.SaverDef.V2)

    def batch_2_feed(self, batch, global_t, use_prior, repeat=1):
        """Convert a raw batch tuple into a feed dict.

        ``repeat > 1`` tiles every array-valued feed along the batch axis
        (used to draw several samples per context); the boolean ``use_prior``
        feed is passed through untouched.
        """
        context, context_lens, floors, topics, my_profiles, ot_profiles, outputs, output_lens, output_das = batch
        feed_dict = {
            self.input_contexts: context,
            self.context_lens: context_lens,
            self.topics: topics,
            self.output_tokens: outputs,
            self.output_lens: output_lens,
            self.use_prior: use_prior
        }
        if repeat > 1:
            tiled_feed_dict = {}
            for key, val in feed_dict.items():
                if key is self.use_prior:
                    tiled_feed_dict[key] = val
                    continue
                multipliers = [1] * len(val.shape)
                multipliers[0] = repeat
                tiled_feed_dict[key] = np.tile(val, multipliers)
            feed_dict = tiled_feed_dict
        if global_t is not None:
            feed_dict[self.global_t] = global_t
        return feed_dict

    def train(self, global_t, sess, train_feed, update_limit=5000):
        """Run one training epoch (or up to ``update_limit`` steps).

        Returns the advanced ``global_t`` and the average elbo loss.
        """
        elbo_losses = []
        rc_losses = []
        rc_ppls = []
        total_words = []
        kl_losses = []
        local_t = 0
        start_time = time.time()
        loss_names = ["elbo_loss", "rc_loss", "kl_loss"]
        while True:
            batch = train_feed.next_new_batch()
            if batch is None:
                break
            if update_limit is not None and local_t >= update_limit:
                break
            feed_dict = self.batch_2_feed(batch, global_t, use_prior=False)
            _, sum_op, elbo_loss, rc_loss, rc_ppl, kl_loss, total_word = sess.run(
                [
                    self.train_ops, self.summary_op, self.elbo,
                    self.avg_rc_loss, self.rc_ppl, self.avg_kld,
                    self.total_word
                ], feed_dict)
            self.train_summary_writer.add_summary(sum_op, global_t)
            total_words.append(total_word)
            elbo_losses.append(elbo_loss)
            rc_ppls.append(rc_ppl)
            rc_losses.append(rc_loss)
            kl_losses.append(kl_loss)
            global_t += 1
            local_t += 1
            # Log ~20 times per epoch; integer division keeps the py2 cadence
            # (under py3 the old "/" made this a float modulo).
            if local_t % max(1, train_feed.num_batch // 20) == 0:
                kl_w = sess.run(self.kl_w, {self.global_t: global_t})
                self.print_loss(
                    "%.2f" % (train_feed.ptr / float(train_feed.num_batch)),
                    loss_names, [elbo_losses, rc_losses, kl_losses],
                    "kl_w %f, perplexity: %f" %
                    (kl_w, np.exp(np.sum(rc_ppls) / np.sum(total_words))))
        # finish epoch!
        epoch_time = time.time() - start_time
        avg_losses = self.print_loss(
            "Epoch Done", loss_names, [elbo_losses, rc_losses, kl_losses],
            "step time %.4f, perplexity: %f" %
            (epoch_time / train_feed.num_batch,
             np.exp(np.sum(rc_ppls) / np.sum(total_words))))
        return global_t, avg_losses[0]

    def valid(self, name, sess, valid_feed):
        """Evaluate losses/perplexity over a feed; return average elbo loss."""
        elbo_losses = []
        rc_losses = []
        rc_ppls = []
        kl_losses = []
        total_words = []
        while True:
            batch = valid_feed.next_new_batch()
            if batch is None:
                break
            feed_dict = self.batch_2_feed(batch,
                                          None,
                                          use_prior=False,
                                          repeat=1)
            elbo_loss, rc_loss, rc_ppl, kl_loss, total_word = sess.run([
                self.elbo, self.avg_rc_loss, self.rc_ppl, self.avg_kld,
                self.total_word
            ], feed_dict)
            total_words.append(total_word)
            elbo_losses.append(elbo_loss)
            rc_losses.append(rc_loss)
            rc_ppls.append(rc_ppl)
            kl_losses.append(kl_loss)
        avg_losses = self.print_loss(
            name, ["elbo_loss", "rc_loss", "kl_loss"],
            [elbo_losses, rc_losses, kl_losses],
            "perplexity: %f" % np.exp(np.sum(rc_ppls) / np.sum(total_words)))
        return avg_losses[0]

    def test(self, sess, test_feed, num_batch=None, repeat=5, dest=sys.stdout):
        """Sample ``repeat`` responses per context, write them to ``dest``,
        and report recall/precision BLEU plus their F1."""
        local_t = 0
        recall_bleus = []
        prec_bleus = []
        while True:
            batch = test_feed.next_new_batch()
            if batch is None or (num_batch is not None
                                 and local_t > num_batch):
                break
            # Sample from the prior; tiling gives `repeat` samples per context.
            feed_dict = self.batch_2_feed(batch,
                                          None,
                                          use_prior=True,
                                          repeat=repeat)
            word_outs = sess.run(self.dec_out_words, feed_dict)
            sample_words = np.split(word_outs, repeat, axis=0)
            true_srcs = feed_dict[self.input_contexts]
            true_src_lens = feed_dict[self.context_lens]
            true_outs = feed_dict[self.output_tokens][:, -1, :]
            true_topics = feed_dict[self.topics]
            local_t += 1

            if dest != sys.stdout:
                # Console progress ~10 times per feed when writing to a file.
                if local_t % max(1, test_feed.num_batch // 10) == 0:
                    print("%.2f >> " %
                          (test_feed.ptr / float(test_feed.num_batch)))

            for b_id in range(test_feed.batch_size):
                dest.write(
                    "Batch %d index %d of topic %s\n" %
                    (local_t, b_id, self.topic_vocab[true_topics[b_id]]))
                # Show at most the last 5 context utterances.
                start = np.maximum(0, true_src_lens[b_id] - 5)
                for t_id in range(start, true_srcs.shape[1], 1):
                    src_str = " ".join([
                        w for w in true_srcs[b_id, t_id].tolist()
                        if w not in ["<pad>"]
                    ])
                    dest.write("Src %d: %s\n" % (t_id, src_str))
                true_tokens = [
                    w for w in true_outs[b_id].tolist()
                    if w not in ["<pad>", "<s>", "</s>"]
                ]
                true_str = " ".join(true_tokens).replace(" ' ", "'")
                dest.write("Target >> %s\n" % (true_str))
                local_tokens = []
                for r_id in range(repeat):
                    pred_outs = sample_words[r_id]
                    # pred_da = np.argmax(sample_das[r_id], axis=1)[0]
                    pred_tokens = [
                        self.vocab[e] for e in pred_outs[b_id].tolist()
                        if e not in [self.eos_id, self.pad_id, self.sos_id]
                    ]
                    pred_str = " ".join(pred_tokens).replace(" ' ", "'")
                    dest.write("Sample %d >> %s\n" % (r_id, pred_str))
                    local_tokens.append(pred_tokens)
                # max over samples ~ recall; mean over samples ~ precision.
                max_bleu, avg_bleu = utils.get_bleu_stats(
                    true_tokens, local_tokens)
                recall_bleus.append(max_bleu)
                prec_bleus.append(avg_bleu)
                dest.write("\n")

        avg_recall_bleu = float(np.mean(recall_bleus))
        avg_prec_bleu = float(np.mean(prec_bleus))
        # 10e-12 guards against division by zero when both BLEUs are 0.
        avg_f1 = 2 * (avg_prec_bleu * avg_recall_bleu) / (
            avg_prec_bleu + avg_recall_bleu + 10e-12)
        report = "Avg recall BLEU %f, avg precision BLEU %f and F1 %f (only 1 reference response. Not final result)" \
                 % (avg_recall_bleu, avg_prec_bleu, avg_f1)
        print(report)
        dest.write(report + "\n")
        print("Done testing")

    def hierarchical_inference_net(self, inputs):
        """Two-level latent inference: sample z1 from inputs, then z2 from
        (z1, inputs); return the concatenated sample and both mu/logvar pairs.
        """
        group_dim = int(self.latent_size / 2)
        recog_mulogvar_1 = tf.contrib.layers.fully_connected(
            inputs, group_dim * 2, activation_fn=None, scope="muvar")
        recog_mu_1, recog_logvar_1 = tf.split(recog_mulogvar_1, 2, axis=-1)
        z_post_1 = sample_gaussian(recog_mu_1, recog_logvar_1)
        # Level 2 conditions on the level-1 sample.
        cont_inputs = tf.concat([z_post_1, inputs], -1)
        recog_mulogvar_2 = tf.contrib.layers.fully_connected(
            cont_inputs, group_dim * 2, activation_fn=None, scope="muvar1")
        recog_mu_2, recog_logvar_2 = tf.split(recog_mulogvar_2, 2, axis=-1)
        z_post_2 = sample_gaussian(recog_mu_2, recog_logvar_2)
        z_post = tf.concat([z_post_1, z_post_2], -1)
        return z_post, recog_mu_1, recog_logvar_1, recog_mu_2, recog_logvar_2
def __init__(self, sess, config, api, log_dir, forward, scope=None):
    """Build the hierarchical-latent dialog VAE graph.

    Args:
        sess: tf.Session used later by self.optimize.
        config: hyper-parameter object (cell sizes, embed sizes, lr, kl step, ...).
        api: data api exposing vocab / rev_vocab / topic_vocab / dialog_act_vocab.
        log_dir: summary dir; also gates KL annealing below (see NOTE).
        forward: True builds the inference decoder only; False builds the
            training decoder plus ELBO loss and the optimizer.
        scope: optional outer scope name, stored but not applied here.
    """
    self.vocab = api.vocab
    self.rev_vocab = api.rev_vocab
    self.vocab_size = len(self.vocab)
    self.topic_vocab = api.topic_vocab
    self.topic_vocab_size = len(self.topic_vocab)
    self.da_vocab = api.dialog_act_vocab
    self.da_vocab_size = len(self.da_vocab)
    self.sess = sess
    self.scope = scope
    # Special-token ids resolved from the reverse vocabulary.
    self.pad_id = self.rev_vocab["<pad>"]
    self.sos_id = self.rev_vocab["<s>"]
    self.eos_id = self.rev_vocab["</s>"]
    self.unk_id = self.rev_vocab["<unk>"]
    self.context_cell_size = config.cxt_cell_size
    self.sent_cell_size = config.sent_cell_size
    self.dec_cell_size = config.dec_cell_size
    self.latent_size = config.latent_size
    with tf.name_scope("io"):
        # String tokens are fed directly; lookup to ids happens in-graph below.
        self.input_contexts = tf.placeholder(dtype=tf.string,
                                             shape=(None, None, None),
                                             name="dialog_context")
        self.context_lens = tf.placeholder(dtype=tf.int32,
                                           shape=(None, ),
                                           name="context_lens")
        self.topics = tf.placeholder(dtype=tf.int32,
                                     shape=(None, ),
                                     name="topics")
        self.output_tokens = tf.placeholder(dtype=tf.string,
                                            shape=(None, None, None),
                                            name="output_token")
        self.output_lens = tf.placeholder(dtype=tf.int32,
                                          shape=(None, None),
                                          name="output_lens")
        self.learning_rate = tf.Variable(float(config.init_lr),
                                         trainable=False,
                                         name="learning_rate")
        self.learning_rate_decay_op = self.learning_rate.assign(
            tf.multiply(self.learning_rate, config.lr_decay))
        # global_t drives KL annealing; use_prior switches posterior/prior sampling.
        self.global_t = tf.placeholder(dtype=tf.int32, name="global_t")
        self.use_prior = tf.placeholder(dtype=tf.bool, name="use_prior")
    # Dynamic dims taken from the fed tensors.
    batch_size = tf.shape(self.input_contexts)[0]
    max_dialog_len = tf.shape(self.input_contexts)[1]
    max_out_len = tf.shape(self.output_tokens)[2]
    with tf.variable_scope("tokenization"):
        # In-graph string -> id table; unknown words map to unk_id.
        self.symbols = tf.Variable(self.vocab, trainable=False, name="symbols")
        self.symbol2index = HashTable(KeyValueTensorInitializer(
            self.symbols,
            tf.Variable(
                np.array([i for i in range(self.vocab_size)], dtype=np.int32),
                False)),
            default_value=self.unk_id,
            name="symbol2index")
        self.contexts = self.symbol2index.lookup(self.input_contexts)
        self.responses_target = self.symbol2index.lookup(self.output_tokens)
    with tf.variable_scope("topic_embedding"):
        t_embedding = tf.get_variable(
            "embedding", [self.topic_vocab_size, config.topic_embed_size],
            dtype=tf.float32)
        topic_embedding = tf.nn.embedding_lookup(t_embedding, self.topics)  # [batch_size, topic_embed_size]
    with tf.variable_scope("word_embedding"):
        self.embedding = tf.get_variable(
            "embedding", [self.vocab_size, config.embed_size],
            dtype=tf.float32)
        # Zero out row 0 (presumably the <pad> id — TODO confirm) so pads embed to 0.
        embedding_mask = tf.constant(
            [0 if i == 0 else 1 for i in range(self.vocab_size)],
            dtype=tf.float32,
            shape=[self.vocab_size, 1])
        embedding = self.embedding * embedding_mask
        input_embedding = tf.nn.embedding_lookup(
            embedding, tf.reshape(self.contexts, [-1]))
        input_embedding = tf.reshape(
            input_embedding, [batch_size * max_dialog_len, -1, config.embed_size])
        output_embedding = tf.nn.embedding_lookup(
            embedding, tf.reshape(self.responses_target, [-1]))
        output_embedding = tf.reshape(
            output_embedding, [batch_size * max_dialog_len, -1, config.embed_size])
    with tf.variable_scope("uttrance_encoder"):
        # Utterance-level encoder; response encoding reuses the same weights.
        if config.sent_type == "rnn":
            sent_cell = self.create_rnn_cell(self.sent_cell_size)
            input_embedding, sent_size = get_rnn_encode(input_embedding,
                                                        sent_cell,
                                                        scope="sent_rnn")
            output_embedding, _ = get_rnn_encode(output_embedding,
                                                 sent_cell,
                                                 tf.reshape(self.output_lens, [-1]),
                                                 scope="sent_rnn",
                                                 reuse=True)
        elif config.sent_type == "bi_rnn":
            fwd_sent_cell = self.create_rnn_cell(self.sent_cell_size)
            bwd_sent_cell = self.create_rnn_cell(self.sent_cell_size)
            input_embedding, sent_size = get_bi_rnn_encode(
                input_embedding, fwd_sent_cell, bwd_sent_cell,
                scope="sent_bi_rnn")
            output_embedding, _ = get_bi_rnn_encode(output_embedding,
                                                    fwd_sent_cell,
                                                    bwd_sent_cell,
                                                    tf.reshape(self.output_lens, [-1]),
                                                    scope="sent_bi_rnn",
                                                    reuse=True)
        else:
            raise ValueError("Unknown sent_type. Must be one of [rnn, bi_rnn]")
        input_embedding = tf.reshape(
            input_embedding, [batch_size, max_dialog_len, sent_size])
        if config.keep_prob < 1.0:
            input_embedding = tf.nn.dropout(input_embedding, config.keep_prob)
        output_embedding = tf.reshape(
            output_embedding, [batch_size, max_dialog_len, sent_size])
    with tf.variable_scope("context_encoder"):
        # Dialog-level RNN over the per-utterance encodings.
        enc_cell = self.create_rnn_cell(self.context_cell_size)
        cxt_outputs, _ = tf.nn.dynamic_rnn(
            enc_cell, input_embedding, dtype=tf.float32,
            sequence_length=self.context_lens)  # [batch_size, max_dialog_len, context_cell_size]
    # Broadcast the topic embedding across the dialog-length axis.
    tile_topic_embedding = tf.reshape(
        tf.tile(topic_embedding, [1, max_dialog_len]),
        [batch_size, max_dialog_len, config.topic_embed_size])
    cond_embedding = tf.concat([tile_topic_embedding, cxt_outputs], -1)  # [batch_size, max_dialog_len, context_cell_size + topic_embed_size]
    with tf.variable_scope("posterior_network"):
        # Posterior q(z | context, response): conditions on the gold response.
        recog_input = tf.concat([cond_embedding, output_embedding], -1)
        post_sample, recog_mu_1, recog_logvar_1, recog_mu_2, recog_logvar_2 = self.hierarchical_inference_net(
            recog_input)
    with tf.variable_scope("prior_network"):
        # Prior p(z | context): same hierarchical net, fresh variables (own scope).
        prior_input = cond_embedding
        prior_sample, prior_mu_1, prior_logvar_1, prior_mu_2, prior_logvar_2 = self.hierarchical_inference_net(
            prior_input)
    # Sample from the prior at generation time, the posterior at training time.
    latent_sample = tf.cond(self.use_prior, lambda: prior_sample,
                            lambda: post_sample)
    with tf.variable_scope("decoder"):
        dec_inputs = tf.concat([cond_embedding, latent_sample], -1)
        dec_inputs_dim = config.latent_size + config.topic_embed_size + self.context_cell_size
        dec_inputs = tf.reshape(
            dec_inputs, [batch_size * max_dialog_len, dec_inputs_dim])
        dec_init_state = tf.contrib.layers.fully_connected(
            dec_inputs, self.dec_cell_size, activation_fn=None,
            scope="init_state")
        dec_cell = self.create_rnn_cell(self.dec_cell_size)
        output_fn, sampled_sequence_loss = output_projection_layer(
            self.dec_cell_size, self.vocab_size)
        decoder_fn_train = decoder_fn.simple_decoder_fn_train(
            dec_init_state, dec_inputs)
        decoder_fn_inference = decoder_fn.simple_decoder_fn_inference(
            output_fn, dec_init_state, dec_inputs, embedding, self.sos_id,
            self.eos_id, max_out_len * 2, self.vocab_size)
        if forward:
            # Inference-only graph: greedy/rollout decoding, no loss ops.
            dec_outs, _, final_context_state = dynamic_rnn_decoder(
                dec_cell, decoder_fn_inference, scope="decoder")
        else:
            # Teacher-forced training path; drop the last token as decoder input.
            dec_input_embedding = tf.nn.embedding_lookup(
                embedding, tf.reshape(self.responses_target, [-1]))
            dec_input_embedding = tf.reshape(
                dec_input_embedding,
                [batch_size * max_dialog_len, -1, config.embed_size])
            dec_input_embedding = dec_input_embedding[:, 0:-1, :]
            dec_seq_lens = tf.reshape(self.output_lens, [-1]) - 1
            if config.dec_keep_prob < 1.0:
                # Word dropout on decoder inputs (whole embedding vectors).
                keep_mask = tf.less_equal(
                    tf.random_uniform(
                        (batch_size * max_dialog_len, max_out_len - 1),
                        minval=0.0, maxval=1.0),
                    config.dec_keep_prob)
                keep_mask = tf.expand_dims(tf.to_float(keep_mask), 2)
                dec_input_embedding = dec_input_embedding * keep_mask
                dec_input_embedding = tf.reshape(
                    dec_input_embedding, [-1, max_out_len - 1, config.embed_size])
            dec_outs, _, final_context_state = dynamic_rnn_decoder(
                dec_cell, decoder_fn_train, dec_input_embedding, dec_seq_lens,
                scope="decoder")
            # Targets are the response shifted left by one token.
            reshape_target = tf.reshape(self.responses_target,
                                        [batch_size * max_dialog_len, -1])
            labels = reshape_target[:, 1:]
            # Non-zero token ids count toward the loss; presumably pad id is 0 — TODO confirm.
            label_mask = tf.to_float(tf.sign(labels))
            local_loss = sampled_sequence_loss(dec_outs, labels, label_mask)
        if final_context_state is not None:
            final_context_state = final_context_state[:, 0:tf.shape(dec_outs)[1]]
            mask = tf.to_int32(tf.sign(tf.reduce_max(dec_outs, axis=2)))
            dec_out_words = tf.multiply(
                tf.reverse(final_context_state, axis=[1]), mask)
        else:
            dec_out_words = tf.argmax(dec_outs, 2)
        # Keep only the last dialog turn's decoded words.
        self.dec_out_words = tf.reshape(
            dec_out_words, [batch_size, max_dialog_len, -1])[:, -1, :]
    if not forward:
        with tf.variable_scope("loss"):
            self.avg_rc_loss = tf.reduce_mean(local_loss)
            # Summed NLL / word count used by callers to derive perplexity.
            self.rc_ppl = tf.reduce_sum(local_loss)
            self.total_word = tf.reduce_sum(label_mask)
            # NOTE(review): recog/prior mu/logvar come from hierarchical_inference_net,
            # whose groups have size latent_size/2, yet they are reshaped to
            # [-1, config.latent_size] here — verify intended.
            new_recog_mu_2 = tf.reshape(recog_mu_2, [-1, config.latent_size])
            new_recog_logvar_2 = tf.reshape(recog_logvar_2, [-1, config.latent_size])
            new_prior_mu_1 = tf.reshape(prior_mu_1, [-1, config.latent_size])
            new_prior_logvar_1 = tf.reshape(prior_logvar_1, [-1, config.latent_size])
            new_recog_mu_1 = tf.reshape(recog_mu_1, [-1, config.latent_size])
            new_recog_logvar_1 = tf.reshape(recog_logvar_1, [-1, config.latent_size])
            new_prior_mu_2 = tf.reshape(prior_mu_2, [-1, config.latent_size])
            new_prior_logvar_2 = tf.reshape(prior_logvar_2, [-1, config.latent_size])
            # NOTE(review): the pairing is crossed — recog group 2 vs prior group 1
            # and recog group 1 vs prior group 2. Confirm this is intentional and
            # not a transposition of the group indices.
            kld_1 = gaussian_kld(new_recog_mu_2, new_recog_logvar_2,
                                 new_prior_mu_1, new_prior_logvar_1)
            kld_2 = gaussian_kld(new_recog_mu_1, new_recog_logvar_1,
                                 new_prior_mu_2, new_prior_logvar_2)
            kld = kld_1 + kld_2
            self.avg_kld = tf.reduce_mean(kld)
            # NOTE(review): KL annealing is enabled only when log_dir is set;
            # otherwise the KL weight is fixed at 1.0 — confirm this gating.
            if log_dir is not None:
                self.kl_w = tf.minimum(
                    tf.to_float(self.global_t) / config.full_kl_step, 1.0)
            else:
                self.kl_w = tf.constant(1.0)
            # ELBO-style objective: reconstruction + annealed KL.
            aug_elbo = self.elbo = self.avg_rc_loss + self.kl_w * self.avg_kld
            tf.summary.scalar("rc_loss", self.avg_rc_loss)
            tf.summary.scalar("elbo", self.elbo)
            tf.summary.scalar("kld", self.avg_kld)
            self.summary_op = tf.summary.merge_all()
            """
            self.log_p_z = norm_log_liklihood(latent_sample, prior_mu, prior_logvar)
            self.log_q_z_xy = norm_log_liklihood(latent_sample, recog_mu, recog_logvar)
            self.est_marginal = tf.reduce_mean(- self.log_p_z + self.log_q_z_xy)
            """
        self.optimize(sess, config, aug_elbo, log_dir)
    self.saver = tf.train.Saver(tf.global_variables(),
                                write_version=tf.train.SaverDef.V2)
def __init__(self, num_symbols, num_qwords, #modify num_embed_units, num_units, num_layers, is_train, vocab=None, embed=None, question_data=True, learning_rate=0.5, learning_rate_decay_factor=0.95, max_gradient_norm=5.0, num_samples=512, max_length=30, use_lstm=False): self.posts = tf.placeholder(tf.string, shape=(None, None)) # batch*len self.posts_length = tf.placeholder(tf.int32, shape=(None)) # batch self.responses = tf.placeholder(tf.string, shape=(None, None)) # batch*len self.responses_length = tf.placeholder(tf.int32, shape=(None)) # batch self.keyword_tensor = tf.placeholder(tf.float32, shape=(None, 3, None)) #(batch * len) * 3 * numsymbol self.word_type = tf.placeholder(tf.int32, shape=(None)) #(batch * len) # build the vocab table (string to index) if is_train: self.symbols = tf.Variable(vocab, trainable=False, name="symbols") else: self.symbols = tf.Variable(np.array(['.']*num_symbols), name="symbols") self.symbol2index = HashTable(KeyValueTensorInitializer(self.symbols, tf.Variable(np.array([i for i in range(num_symbols)], dtype=np.int32), False)), default_value=UNK_ID, name="symbol2index") self.posts_input = self.symbol2index.lookup(self.posts) # batch*len self.responses_target = self.symbol2index.lookup(self.responses) #batch*len batch_size, decoder_len = tf.shape(self.responses)[0], tf.shape(self.responses)[1] self.responses_input = tf.concat([tf.ones([batch_size, 1], dtype=tf.int32)*GO_ID, tf.split(self.responses_target, [decoder_len-1, 1], 1)[0]], 1) # batch*len #delete the last column of responses_target) and add 'GO at the front of it. self.decoder_mask = tf.reshape(tf.cumsum(tf.one_hot(self.responses_length-1, decoder_len), reverse=True, axis=1), [-1, decoder_len]) # bacth * len print "embedding..." 
# build the embedding table (index to vector) if embed is None: # initialize the embedding randomly self.embed = tf.get_variable('embed', [num_symbols, num_embed_units], tf.float32) else: print len(vocab), len(embed), len(embed[0]) print embed # initialize the embedding by pre-trained word vectors self.embed = tf.get_variable('embed', dtype=tf.float32, initializer=embed) self.encoder_input = tf.nn.embedding_lookup(self.embed, self.posts_input) #batch*len*unit self.decoder_input = tf.nn.embedding_lookup(self.embed, self.responses_input) print "embedding finished" if use_lstm: cell = MultiRNNCell([LSTMCell(num_units)] * num_layers) else: cell = MultiRNNCell([GRUCell(num_units)] * num_layers) # rnn encoder encoder_output, encoder_state = dynamic_rnn(cell, self.encoder_input, self.posts_length, dtype=tf.float32, scope="encoder") # get output projection function output_fn, sampled_sequence_loss = output_projection_layer(num_units, num_symbols, num_qwords, num_samples, question_data) print "encoder_output.shape:", encoder_output.get_shape() # get attention function attention_keys, attention_values, attention_score_fn, attention_construct_fn \ = attention_decoder_fn.prepare_attention(encoder_output, 'luong', num_units) # get decoding loop function decoder_fn_train = attention_decoder_fn.attention_decoder_fn_train(encoder_state, attention_keys, attention_values, attention_score_fn, attention_construct_fn) decoder_fn_inference = attention_decoder_fn.attention_decoder_fn_inference(output_fn, self.keyword_tensor, encoder_state, attention_keys, attention_values, attention_score_fn, attention_construct_fn, self.embed, GO_ID, EOS_ID, max_length, num_symbols) if is_train: # rnn decoder self.decoder_output, _, _ = dynamic_rnn_decoder(cell, decoder_fn_train, self.decoder_input, self.responses_length, scope="decoder") # calculate the loss of decoder # self.decoder_output = tf.Print(self.decoder_output, [self.decoder_output]) self.decoder_loss, self.log_perplexity = 
sampled_sequence_loss(self.decoder_output, self.responses_target, self.decoder_mask, self.keyword_tensor, self.word_type) # building graph finished and get all parameters self.params = tf.trainable_variables() for item in tf.trainable_variables(): print item.name, item.get_shape() # initialize the training process self.learning_rate = tf.Variable(float(learning_rate), trainable=False, dtype=tf.float32) self.learning_rate_decay_op = self.learning_rate.assign( self.learning_rate * learning_rate_decay_factor) self.global_step = tf.Variable(0, trainable=False) # calculate the gradient of parameters opt = tf.train.GradientDescentOptimizer(self.learning_rate) gradients = tf.gradients(self.decoder_loss, self.params) clipped_gradients, self.gradient_norm = tf.clip_by_global_norm(gradients, max_gradient_norm) self.update = opt.apply_gradients(zip(clipped_gradients, self.params), global_step=self.global_step) else: # rnn decoder self.decoder_distribution, _, _ = dynamic_rnn_decoder(cell, decoder_fn_inference, scope="decoder") print("self.decoder_distribution.shape():",self.decoder_distribution.get_shape()) self.decoder_distribution = tf.Print(self.decoder_distribution, ["distribution.shape()", tf.reduce_sum(self.decoder_distribution)]) # generating the response self.generation_index = tf.argmax(tf.split(self.decoder_distribution, [2, num_symbols-2], 2)[1], 2) + 2 # for removing UNK self.generation = tf.nn.embedding_lookup(self.symbols, self.generation_index) self.params = tf.trainable_variables() self.saver = tf.train.Saver(tf.global_variables(), write_version=tf.train.SaverDef.V2, max_to_keep=3, pad_step_number=True, keep_checkpoint_every_n_hours=1.0)
def __init__(self, num_symbols, num_embed_units, num_units, is_train,
             vocab=None, content_pos=None, rhetoric_pos=None, embed=None,
             learning_rate=0.1, learning_rate_decay_factor=0.9995,
             max_gradient_norm=5.0, max_length=30, latent_size=128,
             use_lstm=False, num_classes=3, full_kl_step=80000,
             mem_slot_num=4, mem_size=128):
    """Build a CVAE rewriting model: bi-RNN encoders, recognition/prior
    networks, a pattern classifier, a topic memory, and an attention decoder.

    Args:
        num_symbols: vocabulary size.
        num_embed_units: word-embedding dimension.
        num_units: per-direction encoder cell size (decoder uses 2*num_units).
        is_train: True builds the training graph (losses + Momentum update);
            False builds the inference graph (generation ops).
        vocab: word list for the in-graph string->id table (train mode).
        content_pos / rhetoric_pos: vocab positions used to build the content
            and rhetoric word masks.
        embed: optional pre-trained embedding matrix.
        latent_size: CVAE latent dimension; full_kl_step: KL annealing steps.
        num_classes: number of rhetoric pattern labels.
        mem_slot_num / mem_size: topic-memory geometry.
    """
    self.ori_sents = tf.placeholder(tf.string, shape=(None, None))
    self.ori_sents_length = tf.placeholder(tf.int32, shape=(None))
    self.rep_sents = tf.placeholder(tf.string, shape=(None, None))
    self.rep_sents_length = tf.placeholder(tf.int32, shape=(None))
    self.labels = tf.placeholder(tf.float32, shape=(None, num_classes))
    self.use_prior = tf.placeholder(tf.bool)
    self.global_t = tf.placeholder(tf.int32)

    # Multi-hot masks over the vocabulary for content / rhetoric words.
    self.content_mask = tf.reduce_sum(
        tf.one_hot(content_pos, num_symbols, 1.0, 0.0), axis=0)
    self.rhetoric_mask = tf.reduce_sum(
        tf.one_hot(rhetoric_pos, num_symbols, 1.0, 0.0), axis=0)

    w_topic_memory = tf.get_variable(
        name="w_topic_memory", dtype=tf.float32,
        initializer=tf.random_uniform([mem_size, mem_size], -0.1, 0.1))

    # Build the vocab table (string to index); dummy table at inference time
    # (the real one is restored from the checkpoint).
    if is_train:
        self.symbols = tf.Variable(vocab, trainable=False, name="symbols")
    else:
        self.symbols = tf.Variable(np.array(['.'] * num_symbols), name="symbols")
    self.symbol2index = HashTable(KeyValueTensorInitializer(
        self.symbols,
        tf.Variable(
            np.array([i for i in range(num_symbols)], dtype=np.int32), False)),
        default_value=UNK_ID,
        name="symbol2index")

    self.ori_sents_input = self.symbol2index.lookup(self.ori_sents)
    self.rep_sents_target = self.symbol2index.lookup(self.rep_sents)
    batch_size, decoder_len = tf.shape(self.rep_sents)[0], tf.shape(self.rep_sents)[1]

    # BUG FIX: the original created the memory before batch_size was known,
    # as tf.zeros(shape=[None, mem_slot_num, mem_size]) — `None` is not a
    # valid tf.zeros dimension and fails at graph-construction time. Build it
    # here from the dynamic batch size instead (same value: all zeros).
    topic_memory = tf.zeros([batch_size, mem_slot_num, mem_size],
                            dtype=tf.float32, name="topic_memory")

    # Shift target right and prepend GO_ID to form the decoder input.
    self.rep_sents_input = tf.concat(
        [tf.ones([batch_size, 1], dtype=tf.int32) * GO_ID,
         tf.split(self.rep_sents_target, [decoder_len - 1, 1], 1)[0]], 1)
    # 1.0 up to each sequence's length, 0.0 afterwards.
    self.decoder_mask = tf.reshape(
        tf.cumsum(tf.one_hot(self.rep_sents_length - 1, decoder_len),
                  reverse=True, axis=1),
        [-1, decoder_len])

    # Build the embedding table (index to vector).
    if embed is None:
        # Initialize the embedding randomly.
        self.embed = tf.get_variable('embed', [num_symbols, num_embed_units],
                                     tf.float32)
    else:
        # Initialize the embedding by pre-trained word vectors.
        self.embed = tf.get_variable('embed', dtype=tf.float32,
                                     initializer=embed)
    self.pattern_embed = tf.get_variable('pattern_embed',
                                         [num_classes, num_embed_units],
                                         tf.float32)

    self.encoder_input = tf.nn.embedding_lookup(self.embed, self.ori_sents_input)
    self.decoder_input = tf.nn.embedding_lookup(self.embed, self.rep_sents_input)

    if use_lstm:
        cell_fw = LSTMCell(num_units)
        cell_bw = LSTMCell(num_units)
        cell_dec = LSTMCell(2 * num_units)
    else:
        cell_fw = GRUCell(num_units)
        cell_bw = GRUCell(num_units)
        cell_dec = GRUCell(2 * num_units)

    # Origin sentence encoder.
    with variable_scope.variable_scope("encoder"):
        encoder_output, encoder_state = tf.nn.bidirectional_dynamic_rnn(
            cell_fw, cell_bw, self.encoder_input, self.ori_sents_length,
            dtype=tf.float32)
        post_sum_state = tf.concat(encoder_state, 1)
        encoder_output = tf.concat(encoder_output, 2)

    # Response sentence encoder — same weights (scope reuse).
    with variable_scope.variable_scope("encoder", reuse=True):
        decoder_state, decoder_last_state = tf.nn.bidirectional_dynamic_rnn(
            cell_fw, cell_bw, self.decoder_input, self.rep_sents_length,
            dtype=tf.float32)
        response_sum_state = tf.concat(decoder_last_state, 1)

    # Recognition network q(z | x, y).
    with variable_scope.variable_scope("recog_net"):
        recog_input = tf.concat([post_sum_state, response_sum_state], 1)
        recog_mulogvar = tf.contrib.layers.fully_connected(
            recog_input, latent_size * 2, activation_fn=None, scope="muvar")
        recog_mu, recog_logvar = tf.split(recog_mulogvar, 2, axis=1)

    # Prior network p(z | x).
    with variable_scope.variable_scope("prior_net"):
        prior_fc1 = tf.contrib.layers.fully_connected(
            post_sum_state, latent_size * 2, activation_fn=tf.tanh, scope="fc1")
        prior_mulogvar = tf.contrib.layers.fully_connected(
            prior_fc1, latent_size * 2, activation_fn=None, scope="muvar")
        prior_mu, prior_logvar = tf.split(prior_mulogvar, 2, axis=1)

    # Sample z from the prior at generation time, the posterior otherwise.
    latent_sample = tf.cond(
        self.use_prior,
        lambda: sample_gaussian(prior_mu, prior_logvar),
        lambda: sample_gaussian(recog_mu, recog_logvar))

    # Rhetoric-pattern classifier on the latent sample.
    with variable_scope.variable_scope("classifier"):
        classifier_input = latent_sample
        pattern_fc1 = tf.contrib.layers.fully_connected(
            classifier_input, latent_size, activation_fn=tf.tanh,
            scope="pattern_fc1")
        self.pattern_logits = tf.contrib.layers.fully_connected(
            pattern_fc1, num_classes, activation_fn=None,
            scope="pattern_logits")

    self.label_embedding = tf.matmul(self.labels, self.pattern_embed)

    output_fn, my_sequence_loss = output_projection_layer(
        2 * num_units, num_symbols, latent_size, num_embed_units,
        self.content_mask, self.rhetoric_mask)

    attention_keys, attention_values, attention_score_fn, attention_construct_fn = my_attention_decoder_fn.prepare_attention(
        encoder_output, 'luong', 2 * num_units)

    # Project [encoder state; label embedding; z] to the decoder start state.
    with variable_scope.variable_scope("dec_start"):
        temp_start = tf.concat(
            [post_sum_state, self.label_embedding, latent_sample], 1)
        dec_fc1 = tf.contrib.layers.fully_connected(
            temp_start, 2 * num_units, activation_fn=tf.tanh,
            scope="dec_start_fc1")
        dec_fc2 = tf.contrib.layers.fully_connected(
            dec_fc1, 2 * num_units, activation_fn=None, scope="dec_start_fc2")

    if is_train:
        # RNN decoder (teacher forcing).
        topic_memory = self.update_memory(topic_memory, encoder_output)
        extra_info = tf.concat(
            [self.label_embedding, latent_sample, topic_memory], 1)
        decoder_fn_train = my_attention_decoder_fn.attention_decoder_fn_train(
            dec_fc2, attention_keys, attention_values, attention_score_fn,
            attention_construct_fn, extra_info)
        self.decoder_output, _, _ = my_seq2seq.dynamic_rnn_decoder(
            cell_dec, decoder_fn_train, self.decoder_input,
            self.rep_sents_length, scope="decoder")
        # Reconstruction + annealed KL + classification losses.
        self.decoder_loss = my_loss.sequence_loss(
            logits=self.decoder_output,
            targets=self.rep_sents_target,
            weights=self.decoder_mask,
            extra_information=latent_sample,
            label_embedding=self.label_embedding,
            softmax_loss_function=my_sequence_loss)
        temp_klloss = tf.reduce_mean(
            gaussian_kld(recog_mu, recog_logvar, prior_mu, prior_logvar))
        self.kl_weight = tf.minimum(
            tf.to_float(self.global_t) / full_kl_step, 1.0)
        self.klloss = self.kl_weight * temp_klloss  # need to anneal the kl_weight
        temp_labels = tf.argmax(self.labels, 1)
        self.classifierloss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=self.pattern_logits, labels=temp_labels))
        self.loss = self.decoder_loss + self.klloss + self.classifierloss

        # Building graph finished; collect all parameters.
        self.params = tf.trainable_variables()

        # Initialize the training process.
        self.learning_rate = tf.Variable(float(learning_rate),
                                         trainable=False, dtype=tf.float32)
        self.learning_rate_decay_op = self.learning_rate.assign(
            self.learning_rate * learning_rate_decay_factor)
        self.global_step = tf.Variable(0, trainable=False)

        # Clipped-gradient Momentum update.
        opt = tf.train.MomentumOptimizer(self.learning_rate, 0.9)
        gradients = tf.gradients(self.loss, self.params)
        clipped_gradients, self.gradient_norm = tf.clip_by_global_norm(
            gradients, max_gradient_norm)
        self.update = opt.apply_gradients(zip(clipped_gradients, self.params),
                                          global_step=self.global_step)
    else:
        # RNN decoder (free running).
        topic_memory = self.update_memory(topic_memory, encoder_output)
        extra_info = tf.concat(
            [self.label_embedding, latent_sample, topic_memory], 1)
        decoder_fn_inference = my_attention_decoder_fn.attention_decoder_fn_inference(
            output_fn, dec_fc2, attention_keys, attention_values,
            attention_score_fn, attention_construct_fn, self.embed, GO_ID,
            EOS_ID, max_length, num_symbols, extra_info)
        self.decoder_distribution, _, _ = my_seq2seq.dynamic_rnn_decoder(
            cell_dec, decoder_fn_inference, scope="decoder")
        # Drop the first two ids so UNK can never be generated.
        self.generation_index = tf.argmax(
            tf.split(self.decoder_distribution, [2, num_symbols - 2], 2)[1],
            2) + 2  # for removing UNK
        self.generation = tf.nn.embedding_lookup(self.symbols,
                                                 self.generation_index)
        self.params = tf.trainable_variables()

    self.saver = tf.train.Saver(tf.global_variables(),
                                write_version=tf.train.SaverDef.V2,
                                max_to_keep=3,
                                pad_step_number=True,
                                keep_checkpoint_every_n_hours=1.0)
def __init__(self, num_symbols, num_embed_units, num_units, vocab=None, embed=None, name_scope=None, learning_rate=0.0001, learning_rate_decay_factor=0.95, max_gradient_norm=5, l2_lambda=0.2): self.posts = tf.placeholder(tf.string, shape=[None, None]) # batch * len self.posts_length = tf.placeholder(tf.int32, shape=[None]) # batch self.responses = tf.placeholder(tf.string, shape=[None, None]) # batch*len self.responses_length = tf.placeholder(tf.int32, shape=[None]) # batch self.generation = tf.placeholder(tf.string, shape=[None, None]) # batch*len self.generation_length = tf.placeholder(tf.int32, shape=[None]) # batch # build the vocab table (string to index) self.symbols = tf.Variable(vocab, trainable=False, name="symbols") self.symbol2index = HashTable(KeyValueTensorInitializer( self.symbols, tf.Variable( np.array([i for i in range(num_symbols)], dtype=np.int32), False)), default_value=UNK_ID, name="symbol2index") # build the embedding table (index to vector) if embed is None: # initialize the embedding randomly self.embed = tf.get_variable('embed', [num_symbols, num_embed_units], tf.float32) else: # initialize the embedding by pre-trained word vectors self.embed = tf.get_variable('embed', dtype=tf.float32, initializer=embed) self.posts_input = self.symbol2index.lookup( self.posts) # batch * utter_len self.posts_input_embed = tf.nn.embedding_lookup( self.embed, self.posts_input) #batch * utter_len * embed_unit self.responses_input = self.symbol2index.lookup(self.responses) self.responses_input_embed = tf.nn.embedding_lookup( self.embed, self.responses_input) # batch * utter_len * embed_unit self.generation_input = self.symbol2index.lookup(self.generation) self.generation_input_embed = tf.nn.embedding_lookup( self.embed, self.generation_input) # batch * utter_len * embed_unit # Construct bidirectional GRU cells for encoder / decoder cell_fw_post = GRUCell(num_units) cell_bw_post = GRUCell(num_units) cell_fw_resp = GRUCell(num_units) cell_bw_resp = 
GRUCell(num_units) # Encode the post sequence with variable_scope.variable_scope("post_encoder"): posts_state, posts_final_state = tf.nn.bidirectional_dynamic_rnn( cell_fw_post, cell_bw_post, self.posts_input_embed, self.posts_length, dtype=tf.float32) posts_final_state_bid = tf.concat( posts_final_state, 1) # batch_size * (2 * num_units) # Encode the real response sequence with variable_scope.variable_scope("resp_encoder"): responses_state, responses_final_state = tf.nn.bidirectional_dynamic_rnn( cell_fw_resp, cell_bw_resp, self.responses_input_embed, self.responses_length, dtype=tf.float32) responses_final_state_bid = tf.concat(responses_final_state, 1) # Encode the generated response sequence with variable_scope.variable_scope("resp_encoder", reuse=True): generation_state, generation_final_state = tf.nn.bidirectional_dynamic_rnn( cell_fw_resp, cell_bw_resp, self.generation_input_embed, self.generation_length, dtype=tf.float32) generation_final_state_bid = tf.concat(generation_final_state, 1) # Calculate the relevance score between post and real response with variable_scope.variable_scope("calibration"): self.W = tf.get_variable('W', [2 * num_units, 2 * num_units], tf.float32) vec_post = tf.reshape(posts_final_state_bid, [-1, 1, 2 * num_units]) vec_resp = tf.reshape(responses_final_state_bid, [-1, 2 * num_units, 1]) attn_score_true = tf.einsum( 'aij,ajk->aik', tf.einsum('aij,jk->aik', vec_post, self.W), vec_resp) attn_score_true = tf.reshape(attn_score_true, [-1, 1]) fc_true_input = tf.concat([ posts_final_state_bid, responses_final_state_bid, attn_score_true ], 1) self.output_fc_W = tf.get_variable("output_fc_W", [4 * num_units + 1, num_units], tf.float32) self.output_fc_b = tf.get_variable("output_fc_b", [num_units], tf.float32) fc_true = tf.nn.tanh( tf.nn.xw_plus_b(fc_true_input, self.output_fc_W, self.output_fc_b)) # batch_size self.output_W = tf.get_variable("output_W", [num_units, 1], tf.float32) self.output_b = tf.get_variable("output_b", [1], tf.float32) 
self.cost_true = tf.nn.sigmoid( tf.nn.xw_plus_b(fc_true, self.output_W, self.output_b)) # batch_size # Calculate the relevance score between post and generated response with variable_scope.variable_scope("calibration", reuse=True): vec_gen = tf.reshape(generation_final_state_bid, [-1, 2 * num_units, 1]) attn_score_false = tf.einsum( 'aij,ajk->aik', tf.einsum('aij,jk->aik', vec_post, self.W), vec_gen) attn_score_false = tf.reshape(attn_score_false, [-1, 1]) fc_false_input = tf.concat([ posts_final_state_bid, generation_final_state_bid, attn_score_false ], 1) fc_false = tf.nn.tanh( tf.nn.xw_plus_b(fc_false_input, self.output_fc_W, self.output_fc_b)) # batch_size self.cost_false = tf.nn.sigmoid( tf.nn.xw_plus_b(fc_false, self.output_W, self.output_b)) # batch_size self.PR_cost = tf.reduce_mean( tf.reduce_sum(tf.square(self.cost_true - 1.0), axis=1)) self.PG_cost = tf.reduce_mean( tf.reduce_sum(tf.square(self.cost_false), axis=1)) # Use the loss similar to least square GAN self.cost = self.PR_cost / 2.0 + self.PG_cost / 2.0 + l2_lambda * ( tf.nn.l2_loss(self.output_fc_W) + tf.nn.l2_loss(self.output_fc_b) + tf.nn.l2_loss(self.output_W) + tf.nn.l2_loss(self.output_b) + tf.nn.l2_loss(self.W)) # building graph finished and get all parameters self.params = [ k for k in tf.trainable_variables() if name_scope in k.name ] # initialize the training process self.learning_rate = tf.Variable(float(learning_rate), trainable=False, dtype=tf.float32) self.learning_rate_decay_op = self.learning_rate.assign( self.learning_rate * learning_rate_decay_factor) self.global_step = tf.Variable(0, trainable=False) self.adv_global_step = tf.Variable(0, trainable=False) # calculate the gradient of parameters opt = tf.train.AdamOptimizer(self.learning_rate) gradients = tf.gradients(self.cost, self.params) clipped_gradients, self.gradient_norm = tf.clip_by_global_norm( gradients, max_gradient_norm) self.update = opt.apply_gradients(zip(clipped_gradients, self.params), global_step=self.global_step) 
self.reward = tf.reduce_sum(self.cost_false, axis=1) # batch all_variables = [ k for k in tf.global_variables() if name_scope in k.name ] self.saver = tf.train.Saver(all_variables, write_version=tf.train.SaverDef.V2, max_to_keep=5, pad_step_number=True, keep_checkpoint_every_n_hours=1.0) self.adv_saver = tf.train.Saver(all_variables, write_version=tf.train.SaverDef.V2, max_to_keep=5, pad_step_number=True, keep_checkpoint_every_n_hours=1.0)
class Seq2SeqModel(object):
    """Attention-based seq2seq response generator (TF 1.x static graph).

    Builds either the training graph (teacher forcing + SGD update) or the
    inference graph (free-running decoding) depending on `is_train`.
    """

    def __init__(self, num_symbols, num_embed_units, num_units, num_layers,
                 is_train, vocab=None, embed=None, learning_rate=0.5,
                 learning_rate_decay_factor=0.95, max_gradient_norm=5.0,
                 num_samples=512, max_length=30, use_lstm=False):
        """Construct all placeholders, the encoder/decoder, and train ops.

        Args:
            num_symbols: vocabulary size.
            num_embed_units: word-embedding dimension.
            num_units: RNN cell size; num_layers: stacked layers.
            is_train: True -> training graph; False -> inference graph.
            vocab: word list for the in-graph string->id table (train mode).
            embed: optional pre-trained embedding matrix.
            num_samples: forwarded to output_projection_layer (sampled softmax).
            max_length: maximum decoding length at inference time.
            use_lstm: LSTM cells instead of GRU cells.
        """
        self.posts = tf.placeholder(tf.string, shape=(None, None))  # batch*len
        self.posts_length = tf.placeholder(tf.int32, shape=(None))  # batch
        self.responses = tf.placeholder(tf.string, shape=(None, None))  # batch*len
        self.responses_length = tf.placeholder(tf.int32, shape=(None))  # batch

        # build the vocab table (string to index); dummy table at inference
        # time (the real one is restored from the checkpoint)
        if is_train:
            self.symbols = tf.Variable(vocab, trainable=False, name="symbols")
        else:
            self.symbols = tf.Variable(np.array(['.'] * num_symbols),
                                       name="symbols")
        self.symbol2index = HashTable(KeyValueTensorInitializer(
            self.symbols,
            tf.Variable(
                np.array([i for i in range(num_symbols)], dtype=np.int32),
                False)),
            default_value=UNK_ID,
            name="symbol2index")

        self.posts_input = self.symbol2index.lookup(self.posts)  # batch*len
        self.responses_target = self.symbol2index.lookup(
            self.responses)  # batch*len
        batch_size, decoder_len = tf.shape(self.responses)[0], tf.shape(
            self.responses)[1]
        # Decoder input = target shifted right with GO_ID prepended.
        self.responses_input = tf.concat([
            tf.ones([batch_size, 1], dtype=tf.int32) * GO_ID,
            tf.split(self.responses_target, [decoder_len - 1, 1], 1)[0]
        ], 1)  # batch*len
        # 1.0 for positions < responses_length, 0.0 afterwards.
        self.decoder_mask = tf.reshape(
            tf.cumsum(tf.one_hot(self.responses_length - 1, decoder_len),
                      reverse=True, axis=1), [-1, decoder_len])

        # build the embedding table (index to vector)
        if embed is None:
            # initialize the embedding randomly
            self.embed = tf.get_variable('embed',
                                         [num_symbols, num_embed_units],
                                         tf.float32)
        else:
            # initialize the embedding by pre-trained word vectors
            self.embed = tf.get_variable('embed', dtype=tf.float32,
                                         initializer=embed)

        self.encoder_input = tf.nn.embedding_lookup(
            self.embed, self.posts_input)  # batch*len*unit
        self.decoder_input = tf.nn.embedding_lookup(self.embed,
                                                    self.responses_input)

        # BUG FIX: the original used `[cell] * num_layers`, which repeats the
        # SAME cell object so every layer shares one set of weights (and
        # TF >= 1.1 raises on mismatched layer input sizes). Each layer needs
        # its own cell instance.
        if use_lstm:
            cell = MultiRNNCell([LSTMCell(num_units) for _ in range(num_layers)])
        else:
            cell = MultiRNNCell([GRUCell(num_units) for _ in range(num_layers)])

        # rnn encoder
        encoder_output, encoder_state = dynamic_rnn(cell,
                                                    self.encoder_input,
                                                    self.posts_length,
                                                    dtype=tf.float32,
                                                    scope="encoder")
        # get output projection function
        output_fn, sampled_sequence_loss = output_projection_layer(
            num_units, num_symbols, num_samples)

        # get attention function
        attention_keys, attention_values, attention_score_fn, attention_construct_fn \
            = attention_decoder_fn.prepare_attention(encoder_output, 'luong', num_units)
        # get decoding loop functions for training and inference
        decoder_fn_train = attention_decoder_fn.attention_decoder_fn_train(
            encoder_state, attention_keys, attention_values,
            attention_score_fn, attention_construct_fn)
        decoder_fn_inference = attention_decoder_fn.attention_decoder_fn_inference(
            output_fn, encoder_state, attention_keys, attention_values,
            attention_score_fn, attention_construct_fn, self.embed, GO_ID,
            EOS_ID, max_length, num_symbols)

        if is_train:
            # rnn decoder (teacher forcing)
            self.decoder_output, _, _ = dynamic_rnn_decoder(
                cell, decoder_fn_train, self.decoder_input,
                self.responses_length, scope="decoder")
            # calculate the loss of decoder
            self.decoder_loss = sampled_sequence_loss(self.decoder_output,
                                                      self.responses_target,
                                                      self.decoder_mask)

            # building graph finished and get all parameters
            self.params = tf.trainable_variables()

            # initialize the training process
            self.learning_rate = tf.Variable(float(learning_rate),
                                             trainable=False,
                                             dtype=tf.float32)
            self.learning_rate_decay_op = self.learning_rate.assign(
                self.learning_rate * learning_rate_decay_factor)
            self.global_step = tf.Variable(0, trainable=False)

            # calculate the gradient of parameters (clipped SGD)
            opt = tf.train.GradientDescentOptimizer(self.learning_rate)
            gradients = tf.gradients(self.decoder_loss, self.params)
            clipped_gradients, self.gradient_norm = tf.clip_by_global_norm(
                gradients, max_gradient_norm)
            self.update = opt.apply_gradients(zip(clipped_gradients,
                                                  self.params),
                                              global_step=self.global_step)
        else:
            # rnn decoder (free running)
            self.decoder_distribution, _, _ = dynamic_rnn_decoder(
                cell, decoder_fn_inference, scope="decoder")
            # generating the response: drop the first two ids so UNK can
            # never be emitted, then shift indices back
            self.generation_index = tf.argmax(
                tf.split(self.decoder_distribution, [2, num_symbols - 2],
                         2)[1], 2) + 2  # for removing UNK
            self.generation = tf.nn.embedding_lookup(self.symbols,
                                                     self.generation_index)

        self.params = tf.trainable_variables()
        self.saver = tf.train.Saver(tf.global_variables(),
                                    write_version=tf.train.SaverDef.V2,
                                    max_to_keep=3,
                                    pad_step_number=True,
                                    keep_checkpoint_every_n_hours=1.0)

    def print_parameters(self):
        """Print every trainable parameter's name and shape."""
        for item in self.params:
            print('%s: %s' % (item.name, item.get_shape()))

    def step_decoder(self, session, data, forward_only=False):
        """Run one training step (or a loss-only forward pass).

        Args:
            session: tf.Session.
            data: dict with posts / posts_length / responses / responses_length.
            forward_only: if True, compute the loss without updating weights.
        Returns:
            [loss] when forward_only, else [loss, gradient_norm, update].
        """
        input_feed = {
            self.posts: data['posts'],
            self.posts_length: data['posts_length'],
            self.responses: data['responses'],
            self.responses_length: data['responses_length']
        }
        if forward_only:
            output_feed = [self.decoder_loss]
        else:
            output_feed = [self.decoder_loss, self.gradient_norm, self.update]
        return session.run(output_feed, input_feed)

    def inference(self, session, data):
        """Generate responses for the given posts (inference graph only)."""
        input_feed = {
            self.posts: data['posts'],
            self.posts_length: data['posts_length']
        }
        output_feed = [self.generation]
        return session.run(output_feed, input_feed)
class lm_model(object):
    """GRU encoder-decoder (seq2seq) model whose per-token decoder output is
    used as a language-model score for a (query, answer) pair.

    The graph contains both a training path (masked sequence loss + clipped
    Adam updates) and an inference path (greedy decoding that never emits the
    first two vocabulary ids).
    """

    def __init__(self,
                 num_symbols,
                 num_embed_units,
                 num_units,
                 vocab=None,
                 embed=None,
                 name_scope=None,
                 learning_rate=0.001,
                 learning_rate_decay_factor=0.95,
                 max_gradient_norm=5,
                 num_samples=512,
                 max_length=30):
        """Build the computation graph.

        Args:
            num_symbols: vocabulary size.
            num_embed_units: word-embedding dimension.
            num_units: GRU hidden size for both encoder and decoder.
            vocab: sequence of vocabulary strings backing the string->index
                hash table.
            embed: optional pre-trained embedding matrix; random init if None.
            name_scope: substring selecting this model's variables for
                training/saving; None selects every variable.
            learning_rate: initial Adam step size (held in a variable so
                `learning_rate_decay_op` can shrink it).
            learning_rate_decay_factor: multiplier applied by the decay op.
            max_gradient_norm: global-norm gradient clipping threshold.
            num_samples: sampled-softmax sample count for the output layer.
            max_length: maximum decoding length at inference time.
        """
        self.posts = tf.placeholder(tf.string, shape=[None, None])  # batch * len
        self.posts_length = tf.placeholder(tf.int32, shape=[None])  # batch
        self.responses = tf.placeholder(tf.string, shape=[None, None])  # batch * len
        self.responses_length = tf.placeholder(tf.int32, shape=[None])  # batch

        # Build the vocab table (string -> index); OOV words map to UNK_ID.
        self.symbols = tf.Variable(vocab, trainable=False, name="symbols")
        self.symbol2index = HashTable(
            KeyValueTensorInitializer(
                self.symbols,
                tf.Variable(np.arange(num_symbols, dtype=np.int32),
                            trainable=False)),
            default_value=UNK_ID,
            name="symbol2index")

        # Build the embedding table (index -> vector).
        if embed is None:
            # Initialize the embedding randomly.
            self.embed = tf.get_variable('embed',
                                         [num_symbols, num_embed_units],
                                         tf.float32)
        else:
            # Initialize the embedding with pre-trained word vectors.
            self.embed = tf.get_variable('embed',
                                         dtype=tf.float32,
                                         initializer=embed)

        self.posts_input = self.symbol2index.lookup(
            self.posts)  # batch * utter_len
        self.encoder_input = tf.nn.embedding_lookup(
            self.embed, self.posts_input)  # batch * utter_len * embed_unit
        self.responses_target = self.symbol2index.lookup(
            self.responses)  # batch * len

        batch_size = tf.shape(self.responses)[0]
        decoder_len = tf.shape(self.responses)[1]
        # Decoder input: a <GO> column followed by the target shifted right
        # by one (last target token dropped).
        self.responses_input = tf.concat([
            tf.ones([batch_size, 1], dtype=tf.int32) * GO_ID,
            tf.split(self.responses_target, [decoder_len - 1, 1], 1)[0]
        ], 1)  # batch * len
        # Mask is 1.0 for every position up to and including index
        # responses_length - 1, 0.0 afterwards (reverse cumsum of a one-hot).
        self.decoder_mask = tf.reshape(
            tf.cumsum(tf.one_hot(self.responses_length - 1, decoder_len),
                      reverse=True, axis=1),
            [-1, decoder_len])  # batch * len
        self.decoder_input = tf.nn.embedding_lookup(self.embed,
                                                    self.responses_input)

        # Single-layer GRU cells for encoder / decoder.
        cell_enc = GRUCell(num_units)
        cell_dec = GRUCell(num_units)

        # Encode the post; only the final state seeds the decoder.
        _, encoder_state = tf.nn.dynamic_rnn(cell_enc,
                                             self.encoder_input,
                                             self.posts_length,
                                             dtype=tf.float32,
                                             scope="encoder")

        output_fn, sampled_sequence_loss = output_projection_layer(
            num_units, num_symbols, num_samples)

        # Decode the response (training phase).
        with variable_scope.variable_scope('decoder'):
            decoder_fn_train = my_simple_decoder_fn.simple_decoder_fn_train(
                encoder_state)
            self.decoder_output, _, _ = my_seq2seq.dynamic_rnn_decoder(
                cell_dec,
                decoder_fn_train,
                self.decoder_input,
                self.responses_length,
                scope="decoder_rnn")
            self.decoder_loss, self.all_decoder_output = my_loss.sequence_loss(
                self.decoder_output,
                self.responses_target,
                self.decoder_mask,
                softmax_loss_function=sampled_sequence_loss)

        # Decode the response (inference phase); reuse=True shares the
        # training weights.
        with variable_scope.variable_scope('decoder', reuse=True):
            decoder_fn_inference = \
                my_simple_decoder_fn.simple_decoder_fn_inference(
                    output_fn, encoder_state, self.embed, GO_ID, EOS_ID,
                    max_length, num_symbols)
            self.decoder_distribution, _, _ = my_seq2seq.dynamic_rnn_decoder(
                cell_dec, decoder_fn_inference, scope="decoder_rnn")
            # Argmax over ids >= 2 only, then shift back, so UNK (and,
            # presumably, PAD -- verify id layout) is never generated.
            self.generation_index = tf.argmax(
                tf.split(self.decoder_distribution,
                         [2, num_symbols - 2], 2)[1], 2) + 2
            self.generation = tf.nn.embedding_lookup(self.symbols,
                                                     self.generation_index)

        # FIX: `name_scope in k.name` raised TypeError under the default
        # name_scope=None; None now means "select all variables".
        self.params = [
            k for k in tf.trainable_variables()
            if name_scope is None or name_scope in k.name
        ]

        # Initialize the training process.
        self.learning_rate = tf.Variable(float(learning_rate),
                                         trainable=False,
                                         dtype=tf.float32)
        self.learning_rate_decay_op = self.learning_rate.assign(
            self.learning_rate * learning_rate_decay_factor)
        self.global_step = tf.Variable(0, trainable=False)

        # Compute, clip and apply the gradients of the mean loss.
        self.cost = tf.reduce_mean(self.decoder_loss)
        opt = tf.train.AdamOptimizer(self.learning_rate)
        gradients = tf.gradients(self.cost, self.params)
        clipped_gradients, self.gradient_norm = tf.clip_by_global_norm(
            gradients, max_gradient_norm)
        self.update = opt.apply_gradients(zip(clipped_gradients, self.params),
                                          global_step=self.global_step)

        # Same None-guard as for self.params above.
        all_variables = [
            k for k in tf.global_variables()
            if name_scope is None or name_scope in k.name
        ]
        self.saver = tf.train.Saver(all_variables,
                                    write_version=tf.train.SaverDef.V2,
                                    max_to_keep=3,
                                    pad_step_number=True,
                                    keep_checkpoint_every_n_hours=1.0)

    def print_parameters(self):
        """Print the name and shape of every selected trainable parameter."""
        for item in self.params:
            print('%s: %s' % (item.name, item.get_shape()))

    # Conduct a step of training
    def step(self, session, data, forward_only=False):
        """Run one training step (or a forward-only evaluation).

        Returns [cost] when forward_only, else [cost, gradient_norm, update].
        """
        input_feed = {
            self.posts: data['query'],
            self.posts_length: data['len_query'],
            self.responses: data['ans'],
            self.responses_length: data['len_ans']
        }
        if forward_only:
            output_feed = [self.cost]
        else:
            output_feed = [self.cost, self.gradient_norm, self.update]
        return session.run(output_feed, input_feed)

    # Get the language model score during inference
    def inference(self, session, data):
        """Return the per-token decoder output used as the LM score."""
        input_feed = {
            self.posts: data['query'],
            self.posts_length: data['len_query'],
            self.responses: data['ans'],
            self.responses_length: data['len_ans']
        }
        output_feed = [self.all_decoder_output]
        return session.run(output_feed, input_feed)

    # Acquire a batch of data used for training / test
    def gen_train_batched_data(self, data, config):
        """Pad a list of {'query': [...], 'ans': [...]} token lists into
        fixed-size numpy string batches.

        Every sequence is terminated with '_EOS' (hence the +1 lengths) and
        right-padded with '_PAD' to the batch maximum.  When
        config.direction == 0 the answer tokens are reversed; queries never
        are.

        FIX: the original called sent.reverse(), silently mutating the
        caller's data; a reversed copy is used instead (returned batches are
        unchanged).
        """
        len_query = [len(p['query']) + 1 for p in data]
        len_ans = [len(p['ans']) + 1 for p in data]

        def padding(sent, l, is_query=False):
            if config.direction == 0 and not is_query:
                sent = sent[::-1]  # reversed copy: never mutate the input
            return sent + ['_EOS'] + ['_PAD'] * (l - len(sent) - 1)

        batched_query = [
            padding(p['query'], max(len_query), True) for p in data
        ]
        batched_ans = [padding(p['ans'], max(len_ans)) for p in data]
        batched_data = {
            'query': np.array(batched_query),
            'len_query': np.array(len_query, dtype=np.int32),
            'ans': np.array(batched_ans),
            'len_ans': np.array(len_ans, dtype=np.int32)
        }
        return batched_data