def testBasicRNNSeq2Seq(self):
    """Check output and final-state shapes produced by `basic_rnn_seq2seq`.

    Two constant (2, 2) input steps and three constant (2, 2) decoder input
    steps are run through a `GRUCell(2)` wrapped in a 4-unit
    `OutputProjectionWrapper`. Only the number of decoder outputs and the
    shapes of the results are asserted, not their values.
    """
    with self.test_session() as sess, variable_scope.variable_scope(
            "root", initializer=init_ops.constant_initializer(0.5)):
        # List repetition reuses the *same* constant tensor for every step.
        source_steps = [constant_op.constant(0.5, shape=[2, 2])] * 2
        target_steps = [constant_op.constant(0.4, shape=[2, 2])] * 3

        gru = core_rnn_cell_impl.GRUCell(2)
        projected_cell = core_rnn_cell_impl.OutputProjectionWrapper(gru, 4)

        dec_outputs, final_state = seq2seq_lib.basic_rnn_seq2seq(
            source_steps, target_steps, projected_cell)

        sess.run([variables.global_variables_initializer()])

        # One output per decoder input step, each projected to 4 units.
        output_values = sess.run(dec_outputs)
        self.assertEqual(3, len(output_values))
        self.assertEqual((2, 4), output_values[0].shape)

        # The returned state keeps the unprojected GRU size of 2.
        state_values = sess.run([final_state])
        self.assertEqual((2, 2), state_values[0].shape)
def testOutputProjectionWrapper(self):
    """Smoke-test a single step of `OutputProjectionWrapper(GRUCell(3), 2)`.

    The cell is built on zero tensors for input and state; concrete values
    are then fed by tensor name at run time. The state shape is checked
    exactly, while the projected output is compared against a recorded
    value.
    """
    with self.test_session() as sess, variable_scope.variable_scope(
            "root", initializer=init_ops.constant_initializer(0.5)):
        step_input = array_ops.zeros([1, 3])
        prev_state = array_ops.zeros([1, 3])

        wrapped_cell = core_rnn_cell_impl.OutputProjectionWrapper(
            core_rnn_cell_impl.GRUCell(3), 2)
        projected, next_state = wrapped_cell(step_input, prev_state)

        sess.run([variables_lib.global_variables_initializer()])

        # Override the zero tensors by name with the actual test values.
        feed = {
            step_input.name: np.array([[1., 1., 1.]]),
            prev_state.name: np.array([[0.1, 0.1, 0.1]]),
        }
        projected_value, state_value = sess.run([projected, next_state], feed)

        self.assertEqual(state_value.shape, (1, 3))
        # The numbers in results were not calculated, this is just a smoke
        # test.
        self.assertAllClose(projected_value, [[0.231907, 0.231907]])
def testRNNDecoder(self):
    """Check output and final-state shapes produced by `rnn_decoder`.

    An initial state is obtained by running `static_rnn` with a `GRUCell(2)`
    over two constant (2, 2) steps. `rnn_decoder` is then run from that
    state over three constant (2, 2) decoder input steps with a `GRUCell(2)`
    wrapped in a 4-unit `OutputProjectionWrapper`. Only shapes and the
    number of outputs are asserted.
    """
    with self.test_session() as sess, variable_scope.variable_scope(
            "root", initializer=init_ops.constant_initializer(0.5)):
        # List repetition reuses the *same* constant tensor for every step.
        source_steps = [constant_op.constant(0.5, shape=[2, 2])] * 2
        _, initial_state = core_rnn.static_rnn(
            core_rnn_cell_impl.GRUCell(2), source_steps, dtype=dtypes.float32)

        target_steps = [constant_op.constant(0.4, shape=[2, 2])] * 3
        projected_cell = core_rnn_cell_impl.OutputProjectionWrapper(
            core_rnn_cell_impl.GRUCell(2), 4)

        dec_outputs, final_state = seq2seq_lib.rnn_decoder(
            target_steps, initial_state, projected_cell)

        sess.run([variables.global_variables_initializer()])

        # One output per decoder input step, each projected to 4 units.
        output_values = sess.run(dec_outputs)
        self.assertEqual(3, len(output_values))
        self.assertEqual((2, 4), output_values[0].shape)

        # The returned state keeps the unprojected GRU size of 2.
        state_values = sess.run([final_state])
        self.assertEqual((2, 2), state_values[0].shape)
def __init__(self, sess, config, api, log_dir, forward, scope=None):
    """Build the TensorFlow graph of the latent-variable dialog model.

    The graph consists of: input placeholders, topic / dialog-act / word
    embeddings, an utterance encoder, a context RNN, a recognition network
    and a prior network that each emit (mu, logvar) for the latent variable,
    a generation network (bag-of-words and dialog-act heads plus the decoder
    initial state), the decoder itself and, in training mode, the losses and
    the optimizer.

    NOTE(review): the indentation of this method was not recoverable from
    the flattened source, so the extent of each ``with`` scope below was
    reconstructed from data flow. Scope nesting determines variable names;
    confirm against existing checkpoints before relying on it.

    Args:
        sess: Session stored on ``self.sess`` and forwarded to
            ``self.optimize``.
        config: Hyper-parameter object. Fields read here: ``max_utt_len``,
            ``cxt_cell_size``, ``sent_cell_size``, ``dec_cell_size``,
            ``init_lr``, ``lr_decay``, ``topic_embed_size``, ``use_hcf``,
            ``da_embed_size``, ``embed_size``, ``sent_type``, ``keep_prob``,
            ``dec_keep_prob``, ``cell_type``, ``num_layer``, ``latent_size``
            and ``full_kl_step``.
        api: Corpus object providing ``vocab``, ``rev_vocab``,
            ``topic_vocab`` and ``dialog_act_vocab``.
        log_dir: Forwarded to ``self.optimize``. When it is ``None`` the KL
            weight is the constant 1.0 instead of being annealed with
            ``global_t``.
        forward: If true, build the inference decoder and skip the loss and
            optimizer; otherwise build the training decoder and the losses.
        scope: Stored on ``self.scope``; not otherwise used in this method.

    Raises:
        ValueError: If ``config.sent_type`` is not ``"bow"``, ``"rnn"`` or
            ``"bi_rnn"``.
    """
    self.vocab = api.vocab
    self.rev_vocab = api.rev_vocab
    self.vocab_size = len(self.vocab)
    self.topic_vocab = api.topic_vocab
    self.topic_vocab_size = len(self.topic_vocab)
    self.da_vocab = api.dialog_act_vocab
    self.da_vocab_size = len(self.da_vocab)
    self.sess = sess
    self.scope = scope
    self.max_utt_len = config.max_utt_len
    self.go_id = self.rev_vocab["<s>"]
    self.eos_id = self.rev_vocab["</s>"]
    self.context_cell_size = config.cxt_cell_size
    self.sent_cell_size = config.sent_cell_size
    self.dec_cell_size = config.dec_cell_size

    with tf.name_scope("io"):
        # all dialog context and known attributes
        self.input_contexts = tf.placeholder(
            dtype=tf.int32, shape=(None, None, self.max_utt_len),
            name="dialog_context")
        self.floors = tf.placeholder(
            dtype=tf.int32, shape=(None, None), name="floor")
        self.context_lens = tf.placeholder(
            dtype=tf.int32, shape=(None,), name="context_lens")
        self.topics = tf.placeholder(
            dtype=tf.int32, shape=(None,), name="topics")
        self.my_profile = tf.placeholder(
            dtype=tf.float32, shape=(None, 4), name="my_profile")
        self.ot_profile = tf.placeholder(
            dtype=tf.float32, shape=(None, 4), name="ot_profile")

        # target response given the dialog context
        self.output_tokens = tf.placeholder(
            dtype=tf.int32, shape=(None, None), name="output_token")
        self.output_lens = tf.placeholder(
            dtype=tf.int32, shape=(None,), name="output_lens")
        self.output_das = tf.placeholder(
            dtype=tf.int32, shape=(None,), name="output_dialog_acts")

        # optimization related variables
        self.learning_rate = tf.Variable(
            float(config.init_lr), trainable=False, name="learning_rate")
        self.learning_rate_decay_op = self.learning_rate.assign(
            tf.multiply(self.learning_rate, config.lr_decay))
        self.global_t = tf.placeholder(dtype=tf.int32, name="global_t")
        self.use_prior = tf.placeholder(dtype=tf.bool, name="use_prior")

    # Dynamic sizes taken from the placeholders at run time.
    max_dialog_len = array_ops.shape(self.input_contexts)[1]
    max_out_len = array_ops.shape(self.output_tokens)[1]
    batch_size = array_ops.shape(self.input_contexts)[0]

    with variable_scope.variable_scope("topicEmbedding"):
        t_embedding = tf.get_variable(
            "embedding", [self.topic_vocab_size, config.topic_embed_size],
            dtype=tf.float32)
        topic_embedding = embedding_ops.embedding_lookup(
            t_embedding, self.topics)

    if config.use_hcf:
        with variable_scope.variable_scope("dialogActEmbedding"):
            d_embedding = tf.get_variable(
                "embedding", [self.da_vocab_size, config.da_embed_size],
                dtype=tf.float32)
            da_embedding = embedding_ops.embedding_lookup(
                d_embedding, self.output_das)

    with variable_scope.variable_scope("wordEmbedding"):
        self.embedding = tf.get_variable(
            "embedding", [self.vocab_size, config.embed_size],
            dtype=tf.float32)
        # Zero out row 0 so that token id 0 always embeds to a zero vector
        # (presumably the padding id -- confirm against the vocabulary).
        embedding_mask = tf.constant(
            [0 if i == 0 else 1 for i in range(self.vocab_size)],
            dtype=tf.float32, shape=[self.vocab_size, 1])
        embedding = self.embedding * embedding_mask

        input_embedding = embedding_ops.embedding_lookup(
            embedding, tf.reshape(self.input_contexts, [-1]))
        # One row per utterance: [-1, max_utt_len, embed_size].
        input_embedding = tf.reshape(
            input_embedding, [-1, self.max_utt_len, config.embed_size])
        output_embedding = embedding_ops.embedding_lookup(
            embedding, self.output_tokens)

        # Collapse each utterance into a single vector of width sent_size.
        if config.sent_type == "bow":
            input_embedding, sent_size = get_bow(input_embedding)
            output_embedding, _ = get_bow(output_embedding)
        elif config.sent_type == "rnn":
            sent_cell = self.get_rnncell(
                "gru", self.sent_cell_size, config.keep_prob, 1)
            input_embedding, sent_size = get_rnn_encode(
                input_embedding, sent_cell, scope="sent_rnn")
            output_embedding, _ = get_rnn_encode(
                output_embedding, sent_cell, self.output_lens,
                scope="sent_rnn", reuse=True)
        elif config.sent_type == "bi_rnn":
            fwd_sent_cell = self.get_rnncell(
                "gru", self.sent_cell_size, keep_prob=1.0, num_layer=1)
            bwd_sent_cell = self.get_rnncell(
                "gru", self.sent_cell_size, keep_prob=1.0, num_layer=1)
            input_embedding, sent_size = get_bi_rnn_encode(
                input_embedding, fwd_sent_cell, bwd_sent_cell,
                scope="sent_bi_rnn")
            output_embedding, _ = get_bi_rnn_encode(
                output_embedding, fwd_sent_cell, bwd_sent_cell,
                self.output_lens, scope="sent_bi_rnn", reuse=True)
        else:
            raise ValueError(
                "Unknown sent_type. Must be one of [bow, rnn, bi_rnn]")

        # reshape input into dialogs
        input_embedding = tf.reshape(
            input_embedding, [-1, max_dialog_len, sent_size])
        if config.keep_prob < 1.0:
            input_embedding = tf.nn.dropout(input_embedding, config.keep_prob)

        # convert floors into 1 hot
        floor_one_hot = tf.one_hot(
            tf.reshape(self.floors, [-1]), depth=2, dtype=tf.float32)
        floor_one_hot = tf.reshape(floor_one_hot, [-1, max_dialog_len, 2])

        joint_embedding = tf.concat(
            [input_embedding, floor_one_hot], 2, "joint_embedding")

    with variable_scope.variable_scope("contextRNN"):
        enc_cell = self.get_rnncell(
            config.cell_type, self.context_cell_size, keep_prob=1.0,
            num_layer=config.num_layer)
        # and enc_last_state will be same as the true last state
        _, enc_last_state = tf.nn.dynamic_rnn(
            enc_cell,
            joint_embedding,
            dtype=tf.float32,
            sequence_length=self.context_lens)

        if config.num_layer > 1:
            # Multi-layer state is a sequence of per-layer states; join them
            # along the feature axis.
            enc_last_state = tf.concat(enc_last_state, 1)

    # combine with other attributes
    if config.use_hcf:
        attribute_embedding = da_embedding
        attribute_fc1 = layers.fully_connected(
            attribute_embedding, 30, activation_fn=tf.tanh,
            scope="attribute_fc1")

    cond_list = [topic_embedding, self.my_profile, self.ot_profile,
                 enc_last_state]
    cond_embedding = tf.concat(cond_list, 1)

    with variable_scope.variable_scope("recognitionNetwork"):
        if config.use_hcf:
            recog_input = tf.concat(
                [cond_embedding, output_embedding, attribute_fc1], 1)
        else:
            recog_input = tf.concat([cond_embedding, output_embedding], 1)
        self.recog_mulogvar = recog_mulogvar = layers.fully_connected(
            recog_input, config.latent_size * 2, activation_fn=None,
            scope="muvar")
        recog_mu, recog_logvar = tf.split(recog_mulogvar, 2, axis=1)

    with variable_scope.variable_scope("priorNetwork"):
        # P(XYZ)=P(Z|X)P(X)P(Y|X,Z)
        # A plain Python int is used for the layer width: a NumPy scalar
        # (as returned by np.maximum) is not accepted as num_outputs on
        # Python 3.
        prior_hidden_size = int(max(config.latent_size * 2, 100))
        prior_fc1 = layers.fully_connected(
            cond_embedding, prior_hidden_size, activation_fn=tf.tanh,
            scope="fc1")
        prior_mulogvar = layers.fully_connected(
            prior_fc1, config.latent_size * 2, activation_fn=None,
            scope="muvar")
        prior_mu, prior_logvar = tf.split(prior_mulogvar, 2, axis=1)

        # use sampled Z or posterior Z
        latent_sample = tf.cond(
            self.use_prior,
            lambda: sample_gaussian(prior_mu, prior_logvar),
            lambda: sample_gaussian(recog_mu, recog_logvar))

    with variable_scope.variable_scope("generationNetwork"):
        gen_inputs = tf.concat([cond_embedding, latent_sample], 1)

        # BOW loss
        bow_fc1 = layers.fully_connected(
            gen_inputs, 400, activation_fn=tf.tanh, scope="bow_fc1")
        if config.keep_prob < 1.0:
            bow_fc1 = tf.nn.dropout(bow_fc1, config.keep_prob)
        self.bow_logits = layers.fully_connected(
            bow_fc1, self.vocab_size, activation_fn=None,
            scope="bow_project")

        # Y loss
        if config.use_hcf:
            meta_fc1 = layers.fully_connected(
                gen_inputs, 400, activation_fn=tf.tanh, scope="meta_fc1")
            if config.keep_prob < 1.0:
                meta_fc1 = tf.nn.dropout(meta_fc1, config.keep_prob)
            self.da_logits = layers.fully_connected(
                meta_fc1, self.da_vocab_size, scope="da_project")
            da_prob = tf.nn.softmax(self.da_logits)
            # Expected dialog-act embedding under the predicted distribution.
            pred_attribute_embedding = tf.matmul(da_prob, d_embedding)
            # Inference uses the predicted attribute, training the given one.
            if forward:
                selected_attribute_embedding = pred_attribute_embedding
            else:
                selected_attribute_embedding = attribute_embedding
            dec_inputs = tf.concat(
                [gen_inputs, selected_attribute_embedding], 1)
        else:
            self.da_logits = tf.zeros((batch_size, self.da_vocab_size))
            dec_inputs = gen_inputs
            # Bind the name so the decoder_fn calls below do not raise
            # NameError when use_hcf is disabled.
            # NOTE(review): assumes decoder_fn_lib treats a None context
            # vector as "no extra context" -- confirm.
            selected_attribute_embedding = None

        # Decoder
        if config.num_layer > 1:
            # One independently projected initial state per decoder layer.
            dec_init_state = tuple(
                layers.fully_connected(
                    dec_inputs, self.dec_cell_size, activation_fn=None,
                    scope="init_state-%d" % i)
                for i in range(config.num_layer))
        else:
            dec_init_state = layers.fully_connected(
                dec_inputs, self.dec_cell_size, activation_fn=None,
                scope="init_state")

    with variable_scope.variable_scope("decoder"):
        dec_cell = self.get_rnncell(
            config.cell_type, self.dec_cell_size, config.keep_prob,
            config.num_layer)
        dec_cell = rnn_cell.OutputProjectionWrapper(dec_cell, self.vocab_size)

        if forward:
            loop_func = decoder_fn_lib.context_decoder_fn_inference(
                None, dec_init_state, embedding,
                start_of_sequence_id=self.go_id,
                end_of_sequence_id=self.eos_id,
                maximum_length=self.max_utt_len,
                num_decoder_symbols=self.vocab_size,
                context_vector=selected_attribute_embedding)
            dec_input_embedding = None
            dec_seq_lens = None
        else:
            loop_func = decoder_fn_lib.context_decoder_fn_train(
                dec_init_state, selected_attribute_embedding)
            dec_input_embedding = embedding_ops.embedding_lookup(
                embedding, self.output_tokens)
            # Decoder inputs drop the last token; the labels used for the
            # loss below drop the first one.
            dec_input_embedding = dec_input_embedding[:, 0:-1, :]
            dec_seq_lens = self.output_lens - 1

            if config.keep_prob < 1.0:
                dec_input_embedding = tf.nn.dropout(
                    dec_input_embedding, config.keep_prob)

            # apply word dropping. Set dropped word to 0
            if config.dec_keep_prob < 1.0:
                keep_mask = tf.less_equal(
                    tf.random_uniform(
                        (batch_size, max_out_len - 1),
                        minval=0.0, maxval=1.0),
                    config.dec_keep_prob)
                keep_mask = tf.expand_dims(tf.to_float(keep_mask), 2)
                dec_input_embedding = dec_input_embedding * keep_mask
                dec_input_embedding = tf.reshape(
                    dec_input_embedding,
                    [-1, max_out_len - 1, config.embed_size])

        dec_outs, _, final_context_state = dynamic_rnn_decoder(
            dec_cell, loop_func, inputs=dec_input_embedding,
            sequence_length=dec_seq_lens)

        if final_context_state is not None:
            # Trim the context state to the number of decoded steps, zero
            # the positions whose logits are all non-positive, and reverse
            # along the time axis.
            final_context_state = final_context_state[
                :, 0:array_ops.shape(dec_outs)[1]]
            mask = tf.to_int32(tf.sign(tf.reduce_max(dec_outs, axis=2)))
            self.dec_out_words = tf.multiply(
                tf.reverse(final_context_state, axis=[1]), mask)
        else:
            self.dec_out_words = tf.arg_max(dec_outs, 2)

    if not forward:
        with variable_scope.variable_scope("loss"):
            labels = self.output_tokens[:, 1:]
            # 1.0 for non-zero label ids, 0.0 for id 0.
            label_mask = tf.to_float(tf.sign(labels))

            rc_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=dec_outs, labels=labels)
            rc_loss = tf.reduce_sum(
                rc_loss * label_mask, reduction_indices=1)
            self.avg_rc_loss = tf.reduce_mean(rc_loss)
            # used only for perplexity calculation. Not used for optimization
            self.rc_ppl = tf.exp(
                tf.reduce_sum(rc_loss) / tf.reduce_sum(label_mask))

            # Bag-of-words loss ("as n-trial multimodal distribution" in the
            # original note): the same bow_logits are tiled over every
            # target position.
            tile_bow_logits = tf.tile(
                tf.expand_dims(self.bow_logits, 1),
                [1, max_out_len - 1, 1])
            bow_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=tile_bow_logits, labels=labels) * label_mask
            bow_loss = tf.reduce_sum(bow_loss, reduction_indices=1)
            self.avg_bow_loss = tf.reduce_mean(bow_loss)

            # reconstruct the meta info about X
            if config.use_hcf:
                da_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                    logits=self.da_logits, labels=self.output_das)
                self.avg_da_loss = tf.reduce_mean(da_loss)
            else:
                self.avg_da_loss = 0.0

            kld = gaussian_kld(recog_mu, recog_logvar, prior_mu, prior_logvar)
            self.avg_kld = tf.reduce_mean(kld)
            # KL weight ramps linearly with global_t up to 1.0 when a log
            # directory is given; otherwise it is fixed at 1.0.
            if log_dir is not None:
                kl_weights = tf.minimum(
                    tf.to_float(self.global_t) / config.full_kl_step, 1.0)
            else:
                kl_weights = tf.constant(1.0)

            self.kl_w = kl_weights
            self.elbo = self.avg_rc_loss + kl_weights * self.avg_kld
            aug_elbo = self.avg_bow_loss + self.avg_da_loss + self.elbo

            tf.summary.scalar("da_loss", self.avg_da_loss)
            tf.summary.scalar("rc_loss", self.avg_rc_loss)
            tf.summary.scalar("elbo", self.elbo)
            tf.summary.scalar("kld", self.avg_kld)
            tf.summary.scalar("bow_loss", self.avg_bow_loss)

            self.summary_op = tf.summary.merge_all()

            self.log_p_z = norm_log_liklihood(
                latent_sample, prior_mu, prior_logvar)
            self.log_q_z_xy = norm_log_liklihood(
                latent_sample, recog_mu, recog_logvar)
            self.est_marginal = tf.reduce_mean(
                rc_loss + bow_loss - self.log_p_z + self.log_q_z_xy)

        self.optimize(sess, config, aug_elbo, log_dir)

    self.saver = tf.train.Saver(
        tf.global_variables(), write_version=tf.train.SaverDef.V2)