def __init__(self, sess, config, api, log_dir, forward, scope=None):  # forward: True builds the inference (decoding) graph, False the training graph
    self.vocab = api.vocab
    self.rev_vocab = api.rev_vocab
    self.vocab_size = len(self.vocab)
    self.seen_intent = api.seen_intent
    self.rev_seen_intent = api.rev_seen_intent
    self.seen_intent_size = len(self.rev_seen_intent)
    self.unseen_intent = api.unseen_intent
    self.rev_unseen_intent = api.rev_unseen_intent
    self.unseen_intent_size = len(self.rev_unseen_intent)
    self.sess = sess
    self.scope = scope
    self.max_utt_len = config.max_utt_len
    self.go_id = self.rev_vocab["<s>"]
    self.eos_id = self.rev_vocab["</s>"]
    self.sent_cell_size = config.sent_cell_size
    self.dec_cell_size = config.dec_cell_size
    self.label_embed_size = config.label_embed_size
    self.latent_size = config.latent_size
    self.seed = config.seed
    self.use_ot_label = config.use_ot_label
    self.use_rand_ot_label = config.use_rand_ot_label  # only valid if use_ot_label is true; whether to use all other labels
    self.use_rand_fixed_ot_label = config.use_rand_fixed_ot_label  # valid when use_ot_label=true and use_rand_ot_label=true
    if self.use_ot_label:
        self.rand_ot_label_num = config.rand_ot_label_num  # valid when use_ot_label=true and use_rand_ot_label=true
    else:
        self.rand_ot_label_num = self.seen_intent_size - 1

    with tf.name_scope("io"):
        # all dialog context and known attributes
        self.labels = tf.placeholder(dtype=tf.int32, shape=(None,), name="labels")  # each utterance has a label, [batch_size,]
        self.ot_label_rand = tf.placeholder(dtype=tf.int32, shape=(None, None), name="ot_labels_rand")
        self.ot_labels_all = tf.placeholder(dtype=tf.int32, shape=(None, None), name="ot_labels_all")  # (batch_size, len(api.label_vocab)-1)

        # target response given the dialog context
        self.io_tokens = tf.placeholder(dtype=tf.int32, shape=(None, None), name="output_tokens")
        self.io_lens = tf.placeholder(dtype=tf.int32, shape=(None,), name="output_lens")
        self.output_labels = tf.placeholder(dtype=tf.int32, shape=(None,), name="output_labels")

        # optimization related variables
        self.learning_rate = tf.Variable(float(config.init_lr), trainable=False, name="learning_rate")
        self.learning_rate_decay_op = self.learning_rate.assign(tf.multiply(self.learning_rate, config.lr_decay))
        self.global_t = tf.placeholder(dtype=tf.int32, name="global_t")
        self.use_prior = tf.placeholder(dtype=tf.bool, name="use_prior")  # whether to sample z from the prior
        self.prior_mulogvar = tf.placeholder(dtype=tf.float32, shape=(None, config.latent_size * 2), name="prior_mulogvar")
        self.batch_size = tf.placeholder(dtype=tf.int32, name="batch_size")

    max_out_len = array_ops.shape(self.io_tokens)[1]
    # batch_size = array_ops.shape(self.io_tokens)[0]
    batch_size = self.batch_size

    with variable_scope.variable_scope("labelEmbedding", reuse=tf.AUTO_REUSE):
        self.la_embedding = tf.get_variable("embedding", [self.seen_intent_size, config.label_embed_size], dtype=tf.float32)
        label_embedding = embedding_ops.embedding_lookup(self.la_embedding, self.output_labels)  # not used

    with variable_scope.variable_scope("wordEmbedding", reuse=tf.AUTO_REUSE):
        self.embedding = tf.get_variable("embedding", [self.vocab_size, config.embed_size], dtype=tf.float32, trainable=False)
        embedding_mask = tf.constant([0 if i == 0 else 1 for i in range(self.vocab_size)],
                                     dtype=tf.float32, shape=[self.vocab_size, 1])
        embedding = self.embedding * embedding_mask  # broadcast; the first row (PAD) is all zeros

        io_embedding = embedding_ops.embedding_lookup(embedding, self.io_tokens)  # 3-dim
        if config.sent_type == "bow":
            io_embedding, _ = get_bow(io_embedding)
        elif config.sent_type == "rnn":
            sent_cell = self.get_rnncell("gru", self.sent_cell_size, config.keep_prob, 1)
            io_embedding, _ = get_rnn_encode(io_embedding, sent_cell, self.io_lens,
                                             scope="sent_rnn", reuse=tf.AUTO_REUSE)
        elif config.sent_type == "bi_rnn":
            fwd_sent_cell = self.get_rnncell("gru", self.sent_cell_size, keep_prob=1.0, num_layer=1)
            bwd_sent_cell = self.get_rnncell("gru", self.sent_cell_size, keep_prob=1.0, num_layer=1)
            io_embedding, _ = get_bi_rnn_encode(io_embedding, fwd_sent_cell, bwd_sent_cell, self.io_lens,
                                                scope="sent_bi_rnn", reuse=tf.AUTO_REUSE)  # equal to x of the graph, (batch_size, 300*2)
        else:
            raise ValueError("Unknown sent_type. Must be one of [bow, rnn, bi_rnn]")
        # print('==========================', io_embedding)
        # Tensor("models_2/wordEmbedding/sent_bi_rnn/concat:0", shape=(?, 600), dtype=float32)

    # convert labels into one-hot
    my_label_one_hot = tf.one_hot(tf.reshape(self.labels, [-1]), depth=self.seen_intent_size, dtype=tf.float32)  # 2-dim
    if config.use_ot_label:
        if config.use_rand_ot_label:
            ot_label_one_hot = tf.one_hot(tf.reshape(self.ot_label_rand, [-1]),
                                          depth=self.seen_intent_size, dtype=tf.float32)
            ot_label_one_hot = tf.reshape(ot_label_one_hot, [-1, self.seen_intent_size * self.rand_ot_label_num])
        else:
            ot_label_one_hot = tf.one_hot(tf.reshape(self.ot_labels_all, [-1]),
                                          depth=self.seen_intent_size, dtype=tf.float32)
            ot_label_one_hot = tf.reshape(
                ot_label_one_hot,
                [-1, self.seen_intent_size * (self.seen_intent_size - 1)])  # (batch_size, len(api.label_vocab)*(len(api.label_vocab)-1))

    with variable_scope.variable_scope("recognitionNetwork", reuse=tf.AUTO_REUSE):
        recog_input = io_embedding
        self.recog_mulogvar = recog_mulogvar = layers.fully_connected(
            recog_input, config.latent_size * 2, activation_fn=None, scope="muvar")  # config.latent_size=200
        recog_mu, recog_logvar = tf.split(recog_mulogvar, 2, axis=1)  # recognition network output, (batch_size, config.latent_size)

    with variable_scope.variable_scope("priorNetwork", reuse=tf.AUTO_REUSE):
        # p(xyz) = p(z)p(x|z)p(y|xz)
        # prior network parameters; assume a normal distribution
        # prior_mulogvar = tf.constant([[1] * config.latent_size + [0] * config.latent_size] * batch_size,
        #                              dtype=tf.float32, name="muvar")  # cannot be built this way (batch_size is a tensor)
        prior_mulogvar = self.prior_mulogvar
        prior_mu, prior_logvar = tf.split(prior_mulogvar, 2, axis=1)

    # use sampled Z or posterior Z
    latent_sample = tf.cond(
        self.use_prior,  # bool input
        lambda: sample_gaussian(prior_mu, prior_logvar),  # shape equals shape(prior_logvar)
        lambda: sample_gaussian(recog_mu, recog_logvar))  # if ... else ..., (batch_size, config.latent_size)
    self.z = latent_sample

    with variable_scope.variable_scope("generationNetwork", reuse=tf.AUTO_REUSE):
        bow_loss_inputs = latent_sample  # (part of) response network input
        label_inputs = latent_sample
        dec_inputs = latent_sample

        # BOW loss
        if config.use_bow_loss:
            bow_fc1 = layers.fully_connected(bow_loss_inputs, 400, activation_fn=tf.tanh, scope="bow_fc1")  # MLPb network fc layer
            # error1: ValueError: The last dimension of the inputs to `Dense` should be defined. Found `None`.
            if config.keep_prob < 1.0:
                bow_fc1 = tf.nn.dropout(bow_fc1, config.keep_prob)
            self.bow_logits = layers.fully_connected(bow_fc1, self.vocab_size,
                                                     activation_fn=None, scope="bow_project")  # MLPb network fc output

        # Y loss, including the other y
        my_label_fc1 = layers.fully_connected(label_inputs, 400, activation_fn=tf.tanh, scope="my_label_fc1")
        if config.keep_prob < 1.0:
            my_label_fc1 = tf.nn.dropout(my_label_fc1, config.keep_prob)
        # my_label_fc2 = layers.fully_connected(my_label_fc1, 400, activation_fn=tf.tanh, scope="my_label_fc2")
        # if config.keep_prob < 1.0:
        #     my_label_fc2 = tf.nn.dropout(my_label_fc2, config.keep_prob)
        self.my_label_logits = layers.fully_connected(my_label_fc1, self.seen_intent_size, scope="my_label_project")  # MLPy fc output
        my_label_prob = tf.nn.softmax(self.my_label_logits)  # softmax output, (batch_size, label_vocab_size)
        self.my_label_prob = my_label_prob
        pred_my_label_embedding = tf.matmul(my_label_prob, self.la_embedding)  # predicted my label y, (batch_size, label_embed_size)

        if config.use_ot_label:
            if config.use_rand_ot_label:  # use one random other label
                ot_label_fc1 = layers.fully_connected(label_inputs, 400, activation_fn=tf.tanh, scope="ot_label_fc1")
                if config.keep_prob < 1.0:
                    ot_label_fc1 = tf.nn.dropout(ot_label_fc1, config.keep_prob)
                self.ot_label_logits = layers.fully_connected(
                    ot_label_fc1, self.rand_ot_label_num * self.seen_intent_size, scope="ot_label_rand_project")
                ot_label_logits_split = tf.reshape(self.ot_label_logits,
                                                   [-1, self.rand_ot_label_num, self.seen_intent_size])
                ot_label_prob_short = tf.nn.softmax(ot_label_logits_split)
                ot_label_prob = tf.reshape(
                    ot_label_prob_short,
                    [-1, self.rand_ot_label_num * self.seen_intent_size])  # (batch_size, self.rand_ot_label_num*self.label_vocab_size)
                pred_ot_label_embedding = tf.reshape(
                    tf.matmul(ot_label_prob_short, self.la_embedding),
                    [-1, self.label_embed_size * self.rand_ot_label_num])  # predicted other label y2
            else:
                ot_label_fc1 = layers.fully_connected(label_inputs, 400, activation_fn=tf.tanh, scope="ot_label_fc1")
                if config.keep_prob < 1.0:
                    ot_label_fc1 = tf.nn.dropout(ot_label_fc1, config.keep_prob)
                self.ot_label_logits = layers.fully_connected(
                    ot_label_fc1, self.seen_intent_size * (self.seen_intent_size - 1), scope="ot_label_all_project")
                ot_label_logits_split = tf.reshape(self.ot_label_logits,
                                                   [-1, self.seen_intent_size - 1, self.seen_intent_size])
                ot_label_prob_short = tf.nn.softmax(ot_label_logits_split)
                ot_label_prob = tf.reshape(
                    ot_label_prob_short,
                    [-1, self.seen_intent_size * (self.seen_intent_size - 1)])  # (batch_size, self.label_vocab_size*(self.label_vocab_size-1))
                pred_ot_label_embedding = tf.reshape(
                    tf.matmul(ot_label_prob_short, self.la_embedding),
                    [-1, self.label_embed_size * (self.seen_intent_size - 1)])  # predicted all-other-label embeddings, (batch_size, self.label_embed_size*(self.label_vocab_size-1))
                # note: matmul can compute (3, 4, 5) x (5, 4) = (3, 4, 4)
        else:  # only use label y
            self.ot_label_logits = None
            pred_ot_label_embedding = None

        # Decoder (response network)
        if config.num_layer > 1:
            dec_init_state = []
            for i in range(config.num_layer):
                temp_init = layers.fully_connected(dec_inputs, self.dec_cell_size,
                                                   activation_fn=None, scope="init_state-%d" % i)
                if config.cell_type == 'lstm':
                    temp_init = rnn_cell.LSTMStateTuple(temp_init, temp_init)
                dec_init_state.append(temp_init)
            dec_init_state = tuple(dec_init_state)
        else:
            dec_init_state = layers.fully_connected(dec_inputs, self.dec_cell_size,
                                                    activation_fn=None, scope="init_state")
            if config.cell_type == 'lstm':
                dec_init_state = rnn_cell.LSTMStateTuple(dec_init_state, dec_init_state)

    with variable_scope.variable_scope("decoder", reuse=tf.AUTO_REUSE):
        dec_cell = self.get_rnncell(config.cell_type, self.dec_cell_size, config.keep_prob, config.num_layer)
        dec_cell = OutputProjectionWrapper(dec_cell, self.vocab_size)

        if forward:  # test
            loop_func = decoder_fn_lib.context_decoder_fn_inference(
                None, dec_init_state, embedding,
                start_of_sequence_id=self.go_id,
                end_of_sequence_id=self.eos_id,
                maximum_length=self.max_utt_len,
                num_decoder_symbols=self.vocab_size,
                context_vector=None)  # a function
            dec_input_embedding = None
            dec_seq_lens = None
        else:  # train
            loop_func = decoder_fn_lib.context_decoder_fn_train(dec_init_state, None)
            dec_input_embedding = embedding_ops.embedding_lookup(embedding, self.io_tokens)  # x's embedding, (batch_size, utt_len, embed_size)
            dec_input_embedding = dec_input_embedding[:, 0:-1, :]  # drop the last </s>
            dec_seq_lens = self.io_lens - 1  # input placeholder

            if config.keep_prob < 1.0:
                dec_input_embedding = tf.nn.dropout(dec_input_embedding, config.keep_prob)

            # apply word dropping: set dropped words to 0
            if config.dec_keep_prob < 1.0:
                keep_mask = tf.less_equal(
                    tf.random_uniform((batch_size, max_out_len - 1), minval=0.0, maxval=1.0),
                    config.dec_keep_prob)
                keep_mask = tf.expand_dims(tf.to_float(keep_mask), 2)
                dec_input_embedding = dec_input_embedding * keep_mask
                dec_input_embedding = tf.reshape(dec_input_embedding, [-1, max_out_len - 1, config.embed_size])
            # print("=======", dec_input_embedding)
            # Tensor("models/decoder/strided_slice:0", shape=(?, ?, 200), dtype=float32)

        dec_outs, _, final_context_state = dynamic_rnn_decoder(
            dec_cell, loop_func, inputs=dec_input_embedding,
            sequence_length=dec_seq_lens)  # dec_outs: [batch_size, seq, features]

        if final_context_state is not None:
            final_context_state = final_context_state[:, 0:array_ops.shape(dec_outs)[1]]
            mask = tf.to_int32(tf.sign(tf.reduce_max(dec_outs, axis=2)))  # index of the softmax vector's max entry
            self.dec_out_words = tf.multiply(tf.reverse(final_context_state, axis=[1]), mask)
        else:
            self.dec_out_words = tf.argmax(dec_outs, 2)  # (batch_size, utt_len), each element is a word index

    if not forward:
        with variable_scope.variable_scope("loss", reuse=tf.AUTO_REUSE):
            labels = self.io_tokens[:, 1:]  # exclude the first word <s>, (batch_size, utt_len)
            label_mask = tf.to_float(tf.sign(labels))
            labels = tf.one_hot(labels, depth=self.vocab_size, dtype=tf.float32)
            print(dec_outs)
            print(labels)
            # Tensor("models_1/decoder/dynamic_rnn_decoder/transpose_1:0", shape=(?, ?, 892), dtype=float32)
            # Tensor("models_1/loss/strided_slice:0", shape=(?, ?), dtype=int32)

            # rc_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=dec_outs, labels=labels)  # response network loss
            rc_loss = tf.nn.softmax_cross_entropy_with_logits(logits=dec_outs, labels=labels)  # response network loss
            # logits_size=[390,892] labels_size=[1170,892]
            rc_loss = tf.reduce_sum(rc_loss * label_mask, reduction_indices=1)  # (batch_size,); the mask drops padded positions
            self.avg_rc_loss = tf.reduce_mean(rc_loss)  # scalar
            # used only for perplexity calculation; not used for optimization
            self.rc_ppl = tf.exp(tf.reduce_sum(rc_loss) / tf.reduce_sum(label_mask))

            """ as an n-trial multimodal distribution. """
            tile_bow_logits = tf.tile(tf.expand_dims(self.bow_logits, 1),
                                      [1, max_out_len - 1, 1])  # (batch_size, max_out_len-1, vocab_size)
            bow_loss = tf.nn.softmax_cross_entropy_with_logits(
                logits=tile_bow_logits, labels=labels) * label_mask  # (batch_size, max_out_len-1)
            bow_loss = tf.reduce_sum(bow_loss, reduction_indices=1)  # (batch_size,)
            self.avg_bow_loss = tf.reduce_mean(bow_loss)  # scalar

            # the label y (cross-entropy takes the raw logits, not the softmaxed probabilities)
            my_label_loss = tf.nn.softmax_cross_entropy_with_logits(
                logits=self.my_label_logits, labels=my_label_one_hot)  # (batch_size,)
            self.avg_my_label_loss = tf.reduce_mean(my_label_loss)
            if config.use_ot_label:
                ot_label_loss = -tf.nn.softmax_cross_entropy_with_logits(
                    logits=self.ot_label_logits, labels=ot_label_one_hot)
                self.avg_ot_label_loss = tf.reduce_mean(ot_label_loss)
            else:
                self.avg_ot_label_loss = 0.0

            kld = gaussian_kld(recog_mu, recog_logvar, prior_mu, prior_logvar)  # KL divergence, (batch_size,)
            self.avg_kld = tf.reduce_mean(kld)  # scalar
            if log_dir is not None:
                kl_weights = tf.minimum(tf.to_float(self.global_t) / config.full_kl_step, 1.0)
            else:
                kl_weights = tf.constant(1.0)

            self.kl_w = kl_weights
            self.elbo = self.avg_rc_loss + kl_weights * self.avg_kld  # reconstruction loss plus weighted KL divergence

            # ==================================== total loss ==================================== #
            if config.use_rand_ot_label:
                aug_elbo = self.avg_bow_loss + 1000 * self.avg_my_label_loss \
                    + 10 * self.avg_ot_label_loss + self.elbo  # augmented loss  # (1/self.rand_ot_label_num)*
            else:
                aug_elbo = self.avg_bow_loss + 1000 * self.avg_my_label_loss \
                    + 10 * self.avg_ot_label_loss + self.elbo  # augmented loss  # (1/(self.label_vocab_size-1))*

            tf.summary.scalar("rc_loss", self.avg_rc_loss)
            tf.summary.scalar("elbo", self.elbo)
            tf.summary.scalar("kld", self.avg_kld)
            tf.summary.scalar("bow_loss", self.avg_bow_loss)
            tf.summary.scalar("my_label_loss", self.avg_my_label_loss)
            tf.summary.scalar("ot_label_loss", self.avg_ot_label_loss)

            self.summary_op = tf.summary.merge_all()

            self.log_p_z = norm_log_liklihood(latent_sample, prior_mu, prior_logvar)  # log-probability
            self.log_q_z_xy = norm_log_liklihood(latent_sample, recog_mu, recog_logvar)  # log-probability
            self.est_marginal = tf.reduce_mean(rc_loss + bow_loss - self.log_p_z + self.log_q_z_xy)

        self.optimize(sess, config, aug_elbo, log_dir)

    self.saver = tf.train.Saver(tf.global_variables(), write_version=tf.train.SaverDef.V2)
    print('model build finished!')
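
# The Gaussian helpers used above (sample_gaussian, gaussian_kld, norm_log_liklihood) are
# defined elsewhere in the project. A minimal sketch of what they are assumed to compute
# follows (standard reparameterization trick, diagonal-Gaussian KL, and Gaussian log-density);
# the project's real versions may differ in naming or reduction axes. Assumes numpy is
# imported as np, as elsewhere in this file.
def sample_gaussian(mu, logvar):
    """Reparameterization trick: z = mu + sigma * eps, with eps ~ N(0, I)."""
    epsilon = tf.random_normal(tf.shape(logvar), name="epsilon")
    std = tf.exp(0.5 * logvar)
    return mu + tf.multiply(std, epsilon)


def gaussian_kld(recog_mu, recog_logvar, prior_mu, prior_logvar):
    """KL(q(z|x,y) || p(z|x)) between two diagonal Gaussians, one value per batch row."""
    return -0.5 * tf.reduce_sum(
        1 + (recog_logvar - prior_logvar)
        - tf.div(tf.pow(prior_mu - recog_mu, 2), tf.exp(prior_logvar))
        - tf.div(tf.exp(recog_logvar), tf.exp(prior_logvar)),
        reduction_indices=1)


def norm_log_liklihood(x, mu, logvar):
    """Log-density of x under the diagonal Gaussian N(mu, exp(logvar)), summed over dims."""
    return -0.5 * tf.reduce_sum(
        tf.log(2 * np.pi) + logvar + tf.div(tf.pow(x - mu, 2), tf.exp(logvar)),
        reduction_indices=1)
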
def __init__(self, sess, config, api, log_dir, forward, scope=None):
    self.vocab = api.vocab
    self.rev_vocab = api.rev_vocab
    self.vocab_size = len(self.vocab)
    self.topic_vocab = api.topic_vocab
    self.topic_vocab_size = len(self.topic_vocab)
    self.sess = sess
    self.scope = scope
    self.max_utt_len = config.max_utt_len
    self.go_id = self.rev_vocab["<s>"]
    self.eos_id = self.rev_vocab["</s>"]
    self.context_cell_size = config.cxt_cell_size
    self.sent_cell_size = config.sent_cell_size
    self.dec_cell_size = config.dec_cell_size
    self.bow_weights = config.bow_weights

    with tf.name_scope("io"):
        # all dialog context and known attributes
        self.input_contexts = tf.placeholder(dtype=tf.int32, shape=(None, None, self.max_utt_len), name="context")
        self.context_lens = tf.placeholder(dtype=tf.int32, shape=(None,), name="context_lens")

        # target response given the dialog context
        self.output_tokens = tf.placeholder(dtype=tf.int32, shape=(None, None), name="output_token")
        self.output_lens = tf.placeholder(dtype=tf.int32, shape=(None,), name="output_lens")
        self.output_topics = tf.placeholder(dtype=tf.int32, shape=(None,), name="output_topic")

        # optimization related variables
        self.learning_rate = tf.Variable(float(config.init_lr), trainable=False, name="learning_rate")
        self.learning_rate_decay_op = self.learning_rate.assign(tf.multiply(self.learning_rate, config.lr_decay))
        self.global_t = tf.placeholder(dtype=tf.int32, name="global_t")
        self.use_prior = tf.placeholder(dtype=tf.bool, name="use_prior")

    max_context_len = array_ops.shape(self.input_contexts)[1]
    max_out_len = array_ops.shape(self.output_tokens)[1]
    batch_size = array_ops.shape(self.input_contexts)[0]

    if config.use_hcf:
        with variable_scope.variable_scope("topicEmbedding"):
            t_embedding = tf.get_variable("embedding", [self.topic_vocab_size, config.topic_embed_size], dtype=tf.float32)
            topic_embedding = embedding_ops.embedding_lookup(t_embedding, self.output_topics)

    with variable_scope.variable_scope("wordEmbedding"):
        self.embedding = tf.get_variable("embedding", [self.vocab_size, config.embed_size], dtype=tf.float32)
        embedding_mask = tf.constant([0 if i == 0 else 1 for i in range(self.vocab_size)],
                                     dtype=tf.float32, shape=[self.vocab_size, 1])
        embedding = self.embedding * embedding_mask

        input_embedding = embedding_ops.embedding_lookup(embedding, tf.reshape(self.input_contexts, [-1]))
        input_embedding = tf.reshape(input_embedding, [-1, self.max_utt_len, config.embed_size])
        output_embedding = embedding_ops.embedding_lookup(embedding, self.output_tokens)

        # context nn
        if config.sent_type == "bow":
            input_embedding, sent_size = get_bow(input_embedding)
            output_embedding, _ = get_bow(output_embedding)
        elif config.sent_type == "rnn":
            sent_cell = self.get_rnncell("gru", self.sent_cell_size, config.keep_prob, 1)
            input_embedding, sent_size = get_rnn_encode(input_embedding, sent_cell, scope="sent_rnn")
            output_embedding, _ = get_rnn_encode(output_embedding, sent_cell, self.output_lens,
                                                 scope="sent_rnn", reuse=True)
        elif config.sent_type == "bi_rnn":
            fwd_sent_cell = self.get_rnncell("gru", self.sent_cell_size, keep_prob=1.0, num_layer=1)
            bwd_sent_cell = self.get_rnncell("gru", self.sent_cell_size, keep_prob=1.0, num_layer=1)
            input_embedding, sent_size = get_bi_rnn_encode(input_embedding, fwd_sent_cell, bwd_sent_cell,
                                                           scope="sent_bi_rnn")
            output_embedding, _ = get_bi_rnn_encode(output_embedding, fwd_sent_cell, bwd_sent_cell,
                                                    self.output_lens, scope="sent_bi_rnn", reuse=True)
        else:
            raise ValueError("Unknown sent_type. Must be one of [bow, rnn, bi_rnn]")

        # reshape input into dialogs
        input_embedding = tf.reshape(input_embedding, [-1, max_context_len, sent_size])
        if config.keep_prob < 1.0:
            input_embedding = tf.nn.dropout(input_embedding, config.keep_prob)

    with variable_scope.variable_scope("contextRNN"):
        enc_cell = self.get_rnncell(config.cell_type, self.context_cell_size,
                                    keep_prob=1.0, num_layer=config.num_layer)
        # and enc_last_state will be the same as the true last state
        _, enc_last_state = tf.nn.dynamic_rnn(enc_cell, input_embedding, dtype=tf.float32,
                                              sequence_length=self.context_lens)
        if config.num_layer > 1:
            enc_last_state = tf.concat(enc_last_state, 1)

    # combine with other attributes
    if config.use_hcf:
        attribute_embedding = topic_embedding
        attribute_fc1 = layers.fully_connected(attribute_embedding, 30, activation_fn=tf.tanh, scope="attribute_fc1")
    cond_embedding = enc_last_state

    with variable_scope.variable_scope("recognitionNetwork"):
        if config.use_hcf:
            recog_input = tf.concat([cond_embedding, output_embedding, attribute_fc1], 1)
        else:
            recog_input = tf.concat([cond_embedding, output_embedding], 1)
        self.recog_mulogvar = recog_mulogvar = layers.fully_connected(
            recog_input, config.latent_size * 2, activation_fn=None, scope="muvar")
        recog_mu, recog_logvar = tf.split(recog_mulogvar, 2, axis=1)

    with variable_scope.variable_scope("priorNetwork"):
        prior_fc1 = layers.fully_connected(cond_embedding, np.maximum(config.latent_size * 2, 100),
                                           activation_fn=tf.tanh, scope="fc1")
        prior_mulogvar = layers.fully_connected(prior_fc1, config.latent_size * 2,
                                                activation_fn=None, scope="muvar")
        prior_mu, prior_logvar = tf.split(prior_mulogvar, 2, axis=1)

        # use sampled Z or posterior Z
        latent_sample = tf.cond(self.use_prior,
                                lambda: sample_gaussian(prior_mu, prior_logvar),
                                lambda: sample_gaussian(recog_mu, recog_logvar))

    with variable_scope.variable_scope("generationNetwork"):
        gen_inputs = tf.concat([cond_embedding, latent_sample], 1)

        # BOW loss
        bow_fc1 = layers.fully_connected(gen_inputs, 400, activation_fn=tf.tanh, scope="bow_fc1")
        if config.keep_prob < 1.0:
            bow_fc1 = tf.nn.dropout(bow_fc1, config.keep_prob)
        self.bow_logits = layers.fully_connected(bow_fc1, self.vocab_size,
                                                 activation_fn=None, scope="bow_project")

        # Y loss
        if config.use_hcf:
            meta_fc1 = layers.fully_connected(latent_sample, 400, activation_fn=tf.tanh, scope="meta_fc1")
            if config.keep_prob < 1.0:
                meta_fc1 = tf.nn.dropout(meta_fc1, config.keep_prob)
            self.topic_logits = layers.fully_connected(meta_fc1, self.topic_vocab_size, scope="topic_project")
            topic_prob = tf.nn.softmax(self.topic_logits)
            # pred_attribute_embedding = tf.matmul(topic_prob, t_embedding)
            pred_topic = tf.argmax(topic_prob, 1)
            pred_attribute_embedding = embedding_ops.embedding_lookup(t_embedding, pred_topic)
            if forward:
                selected_attribute_embedding = pred_attribute_embedding
            else:
                selected_attribute_embedding = attribute_embedding
            dec_inputs = tf.concat([gen_inputs, selected_attribute_embedding], 1)
        else:
            self.topic_logits = tf.zeros((batch_size, self.topic_vocab_size))
            selected_attribute_embedding = None
            dec_inputs = gen_inputs

        # Decoder
        if config.num_layer > 1:
            dec_init_state = [
                layers.fully_connected(dec_inputs, self.dec_cell_size,
                                       activation_fn=None, scope="init_state-%d" % i)
                for i in range(config.num_layer)
            ]
            dec_init_state = tuple(dec_init_state)
        else:
            dec_init_state = layers.fully_connected(dec_inputs, self.dec_cell_size,
                                                    activation_fn=None, scope="init_state")

    with variable_scope.variable_scope("decoder"):
        dec_cell = self.get_rnncell(config.cell_type, self.dec_cell_size, config.keep_prob, config.num_layer)
        dec_cell = OutputProjectionWrapper(dec_cell, self.vocab_size)

        if forward:
            loop_func = decoder_fn_lib.context_decoder_fn_inference(
                None, dec_init_state, embedding,
                start_of_sequence_id=self.go_id,
                end_of_sequence_id=self.eos_id,
                maximum_length=self.max_utt_len,
                num_decoder_symbols=self.vocab_size,
                context_vector=selected_attribute_embedding)
            dec_input_embedding = None
            dec_seq_lens = None
        else:
            loop_func = decoder_fn_lib.context_decoder_fn_train(dec_init_state, selected_attribute_embedding)
            dec_input_embedding = embedding_ops.embedding_lookup(embedding, self.output_tokens)
            dec_input_embedding = dec_input_embedding[:, 0:-1, :]
            dec_seq_lens = self.output_lens - 1

            if config.keep_prob < 1.0:
                dec_input_embedding = tf.nn.dropout(dec_input_embedding, config.keep_prob)

            # apply word dropping: set dropped words to 0
            if config.dec_keep_prob < 1.0:
                keep_mask = tf.less_equal(
                    tf.random_uniform((batch_size, max_out_len - 1), minval=0.0, maxval=1.0),
                    config.dec_keep_prob)
                keep_mask = tf.expand_dims(tf.to_float(keep_mask), 2)
                dec_input_embedding = dec_input_embedding * keep_mask
                dec_input_embedding = tf.reshape(dec_input_embedding, [-1, max_out_len - 1, config.embed_size])

        dec_outs, _, final_context_state = dynamic_rnn_decoder(
            dec_cell, loop_func, inputs=dec_input_embedding, sequence_length=dec_seq_lens)

        if final_context_state is not None:
            final_context_state = final_context_state[:, 0:array_ops.shape(dec_outs)[1]]
            mask = tf.to_int32(tf.sign(tf.reduce_max(dec_outs, axis=2)))
            self.dec_out_words = tf.multiply(tf.reverse(final_context_state, axis=[1]), mask)
        else:
            self.dec_out_words = tf.argmax(dec_outs, 2)

    if not forward:
        with variable_scope.variable_scope("loss"):
            labels = self.output_tokens[:, 1:]
            label_mask = tf.to_float(tf.sign(labels))

            rc_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=dec_outs, labels=labels)
            rc_loss = tf.reduce_sum(rc_loss * label_mask, reduction_indices=1)
            self.avg_rc_loss = tf.reduce_mean(rc_loss)
            # used only for perplexity calculation; not used for optimization
            self.rc_ppl = tf.exp(tf.reduce_sum(rc_loss) / tf.reduce_sum(label_mask))

            """ as an n-trial multimodal distribution. """
            tile_bow_logits = tf.tile(tf.expand_dims(self.bow_logits, 1), [1, max_out_len - 1, 1])
            bow_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=tile_bow_logits, labels=labels) * label_mask
            bow_loss = tf.reduce_sum(bow_loss, reduction_indices=1)
            self.avg_bow_loss = tf.reduce_mean(bow_loss)
            bow_weights = tf.to_float(self.bow_weights)

            # reconstruct the meta info about X
            if config.use_hcf:
                topic_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                    logits=self.topic_logits, labels=self.output_topics)
                self.avg_topic_loss = tf.reduce_mean(topic_loss)
            else:
                self.avg_topic_loss = 0.0

            kld = gaussian_kld(recog_mu, recog_logvar, prior_mu, prior_logvar)
            self.avg_kld = tf.reduce_mean(kld)
            if log_dir is not None:
                kl_weights = tf.minimum(tf.to_float(self.global_t) / config.full_kl_step, 1.0)
            else:
                kl_weights = tf.constant(1.0)

            self.kl_w = kl_weights
            self.elbo = self.avg_rc_loss + kl_weights * self.avg_kld
            aug_elbo = bow_weights * self.avg_bow_loss + self.avg_topic_loss + self.elbo

            tf.summary.scalar("topic_loss", self.avg_topic_loss)
            tf.summary.scalar("rc_loss", self.avg_rc_loss)
            tf.summary.scalar("elbo", self.elbo)
            tf.summary.scalar("kld", self.avg_kld)
            tf.summary.scalar("bow_loss", self.avg_bow_loss)

            self.summary_op = tf.summary.merge_all()

            self.log_p_z = norm_log_liklihood(latent_sample, prior_mu, prior_logvar)
            self.log_q_z_xy = norm_log_liklihood(latent_sample, recog_mu, recog_logvar)
            self.est_marginal = tf.reduce_mean(rc_loss + bow_loss - self.log_p_z + self.log_q_z_xy)

        self.optimize(sess, config, aug_elbo, log_dir)

    self.saver = tf.train.Saver(tf.global_variables(), write_version=tf.train.SaverDef.V2)
def __init__(self, sess, config, api, log_dir, forward, scope=None):
    self.vocab = api.vocab  # index2word
    self.rev_vocab = api.rev_vocab  # word2index
    self.vocab_size = len(self.vocab)  # vocab size
    self.emotion_vocab = api.emotion_vocab  # index2emotion
    self.emotion_vocab_size = len(self.emotion_vocab)
    # self.da_vocab = api.dialog_act_vocab
    # self.da_vocab_size = len(self.da_vocab)
    self.sess = sess
    self.scope = scope
    self.max_utt_len = config.max_utt_len
    self.go_id = self.rev_vocab["<s>"]
    self.eos_id = self.rev_vocab["</s>"]
    self.context_cell_size = config.cxt_cell_size  # not needed here
    self.sent_cell_size = config.sent_cell_size  # for the utterance encoder
    self.dec_cell_size = config.dec_cell_size  # for the decoder

    with tf.name_scope("io"):
        self.input_contexts = tf.placeholder(dtype=tf.int32, shape=(None, self.max_utt_len), name="input_contexts")
        # self.floors = tf.placeholder(dtype=tf.int32, shape=(None, None), name="floor")
        self.input_lens = tf.placeholder(dtype=tf.int32, shape=(None,), name="input_lens")
        self.input_emotions = tf.placeholder(dtype=tf.int32, shape=(None,), name="input_emotions")
        # self.my_profile = tf.placeholder(dtype=tf.float32, shape=(None, 4), name="my_profile")
        # self.ot_profile = tf.placeholder(dtype=tf.float32, shape=(None, 4), name="ot_profile")

        # target response given the dialog context
        self.output_tokens = tf.placeholder(dtype=tf.int32, shape=(None, None), name="output_token")
        self.output_lens = tf.placeholder(dtype=tf.int32, shape=(None,), name="output_lens")
        self.output_emotions = tf.placeholder(dtype=tf.int32, shape=(None,), name="output_emotions")

        # optimization related variables
        self.learning_rate = tf.Variable(float(config.init_lr), trainable=False, name="learning_rate")
        self.learning_rate_decay_op = self.learning_rate.assign(tf.multiply(self.learning_rate, config.lr_decay))
        self.global_t = tf.placeholder(dtype=tf.int32, name="global_t")
        self.use_prior = tf.placeholder(dtype=tf.bool, name="use_prior")

    max_dialog_len = array_ops.shape(self.input_contexts)[1]
    max_out_len = array_ops.shape(self.output_tokens)[1]
    batch_size = array_ops.shape(self.input_contexts)[0]

    with variable_scope.variable_scope("emotionEmbedding"):
        t_embedding = tf.get_variable("embedding", [self.emotion_vocab_size, config.topic_embed_size], dtype=tf.float32)
        inp_emotion_embedding = embedding_ops.embedding_lookup(t_embedding, self.input_emotions)
        outp_emotion_embedding = embedding_ops.embedding_lookup(t_embedding, self.output_emotions)

    with variable_scope.variable_scope("wordEmbedding"):
        self.embedding = tf.get_variable("embedding", [self.vocab_size, config.embed_size], dtype=tf.float32)
        embedding_mask = tf.constant([0 if i == 0 else 1 for i in range(self.vocab_size)],
                                     dtype=tf.float32, shape=[self.vocab_size, 1])
        embedding = self.embedding * embedding_mask

        input_embedding = embedding_ops.embedding_lookup(embedding, tf.reshape(self.input_contexts, [-1]))
        input_embedding = tf.reshape(input_embedding, [-1, self.max_utt_len, config.embed_size])
        output_embedding = embedding_ops.embedding_lookup(embedding, self.output_tokens)

        if config.sent_type == "rnn":
            enc_cell = self.get_rnncell(config.cell_type, self.context_cell_size,
                                        keep_prob=1.0, num_layer=config.num_layer)
            _, enc_last_state = tf.nn.dynamic_rnn(enc_cell, input_embedding, dtype=tf.float32,
                                                  sequence_length=self.input_lens)
            sent_cell = self.get_rnncell("gru", self.sent_cell_size, config.keep_prob, 1)
            # input_embedding, sent_size = get_rnn_encode(input_embedding, sent_cell, scope="sent_rnn")
            output_embedding, _ = get_rnn_encode(output_embedding, sent_cell, self.output_lens,
                                                 scope="sent_rnn")
        elif config.sent_type == "bi_rnn":
            fwd_enc_cell = self.get_rnncell(config.cell_type, self.context_cell_size,
                                            keep_prob=1.0, num_layer=config.num_layer)
            bwd_enc_cell = self.get_rnncell(config.cell_type, self.context_cell_size,
                                            keep_prob=1.0, num_layer=config.num_layer)
            _, enc_last_state = tf.nn.bidirectional_dynamic_rnn(fwd_enc_cell, bwd_enc_cell, input_embedding,
                                                                dtype=tf.float32, sequence_length=self.input_lens)
            enc_last_state = enc_last_state[0] + enc_last_state[1]
            fwd_sent_cell = self.get_rnncell("gru", self.sent_cell_size, keep_prob=1.0, num_layer=1)
            bwd_sent_cell = self.get_rnncell("gru", self.sent_cell_size, keep_prob=1.0, num_layer=1)
            # input_embedding, sent_size = get_bi_rnn_encode(input_embedding, fwd_sent_cell, bwd_sent_cell, scope="sent_bi_rnn")
            output_embedding, _ = get_bi_rnn_encode(output_embedding, fwd_sent_cell, bwd_sent_cell,
                                                    self.output_lens, scope="sent_bi_rnn")
        else:
            raise ValueError("Unknown sent_type. Must be one of [rnn, bi_rnn]")

        # reshape input into dialogs
        # input_embedding = tf.reshape(input_embedding, [-1, max_dialog_len, sent_size])
        # if config.keep_prob < 1.0:
        #     input_embedding = tf.nn.dropout(input_embedding, config.keep_prob)

        # convert floors into one-hot
        # floor_one_hot = tf.one_hot(tf.reshape(self.floors, [-1]), depth=2, dtype=tf.float32)
        # floor_one_hot = tf.reshape(floor_one_hot, [-1, max_dialog_len, 2])
        # joint_embedding = tf.concat([input_embedding, floor_one_hot], 2, "joint_embedding")

    with variable_scope.variable_scope("contextRNN"):
        if config.num_layer > 1:
            if config.cell_type == 'lstm':
                enc_last_state = [temp.h for temp in enc_last_state]
            enc_last_state = tf.concat(enc_last_state, 1)
        else:
            if config.cell_type == 'lstm':
                enc_last_state = enc_last_state.h

    attribute_fc1 = layers.fully_connected(outp_emotion_embedding, 30, activation_fn=tf.tanh, scope="attribute_fc1")
    cond_embedding = tf.concat([inp_emotion_embedding, enc_last_state], 1)

    with variable_scope.variable_scope("recognitionNetwork"):
        recog_input = tf.concat([cond_embedding, output_embedding, attribute_fc1], 1)
        self.recog_mulogvar = recog_mulogvar = layers.fully_connected(
            recog_input, config.latent_size * 2, activation_fn=None, scope="muvar")
        recog_mu, recog_logvar = tf.split(recog_mulogvar, 2, axis=1)

    with variable_scope.variable_scope("priorNetwork"):
        prior_fc1 = layers.fully_connected(cond_embedding, np.maximum(config.latent_size * 2, 100),
                                           activation_fn=tf.tanh, scope="fc1")
        prior_mulogvar = layers.fully_connected(prior_fc1, config.latent_size * 2,
                                                activation_fn=None, scope="muvar")
        prior_mu, prior_logvar = tf.split(prior_mulogvar, 2, axis=1)

        # use sampled Z or posterior Z
        latent_sample = tf.cond(self.use_prior,
                                lambda: sample_gaussian(prior_mu, prior_logvar),
                                lambda: sample_gaussian(recog_mu, recog_logvar))

    with variable_scope.variable_scope("generationNetwork"):
        gen_inputs = tf.concat([cond_embedding, latent_sample], 1)

        # BOW loss
        bow_fc1 = layers.fully_connected(gen_inputs, 400, activation_fn=tf.tanh, scope="bow_fc1")
        if config.keep_prob < 1.0:
            bow_fc1 = tf.nn.dropout(bow_fc1, config.keep_prob)
        self.bow_logits = layers.fully_connected(bow_fc1, self.vocab_size,
                                                 activation_fn=None, scope="bow_project")

        # Y loss
        meta_fc1 = layers.fully_connected(gen_inputs, 400, activation_fn=tf.tanh, scope="meta_fc1")
        if config.keep_prob < 1.0:
            meta_fc1 = tf.nn.dropout(meta_fc1, config.keep_prob)
        self.da_logits = layers.fully_connected(meta_fc1, self.emotion_vocab_size, scope="da_project")
        da_prob = tf.nn.softmax(self.da_logits)
        pred_attribute_embedding = tf.matmul(da_prob, t_embedding)
        if forward:
            selected_attribute_embedding = pred_attribute_embedding
        else:
            selected_attribute_embedding = outp_emotion_embedding
        dec_inputs = tf.concat([gen_inputs, selected_attribute_embedding], 1)

        # Decoder
        if config.num_layer > 1:
            dec_init_state = []
            for i in range(config.num_layer):
                temp_init = layers.fully_connected(dec_inputs, self.dec_cell_size,
                                                   activation_fn=None, scope="init_state-%d" % i)
                if config.cell_type == 'lstm':
                    temp_init = rnn_cell.LSTMStateTuple(temp_init, temp_init)
                dec_init_state.append(temp_init)
            dec_init_state = tuple(dec_init_state)
        else:
            dec_init_state = layers.fully_connected(dec_inputs, self.dec_cell_size,
                                                    activation_fn=None, scope="init_state")
            if config.cell_type == 'lstm':
                dec_init_state = rnn_cell.LSTMStateTuple(dec_init_state, dec_init_state)

    with variable_scope.variable_scope("decoder"):
        dec_cell = self.get_rnncell(config.cell_type, self.dec_cell_size, config.keep_prob, config.num_layer)
        dec_cell = OutputProjectionWrapper(dec_cell, self.vocab_size)

        if forward:
            loop_func = decoder_fn_lib.context_decoder_fn_inference(
                None, dec_init_state, embedding,
                start_of_sequence_id=self.go_id,
                end_of_sequence_id=self.eos_id,
                maximum_length=self.max_utt_len,
                num_decoder_symbols=self.vocab_size,
                context_vector=selected_attribute_embedding)
            dec_input_embedding = None
            dec_seq_lens = None
        else:
            loop_func = decoder_fn_lib.context_decoder_fn_train(dec_init_state, selected_attribute_embedding)
            dec_input_embedding = embedding_ops.embedding_lookup(embedding, self.output_tokens)
            dec_input_embedding = dec_input_embedding[:, 0:-1, :]
            dec_seq_lens = self.output_lens - 1

            if config.keep_prob < 1.0:
                dec_input_embedding = tf.nn.dropout(dec_input_embedding, config.keep_prob)

            # apply word dropping: set dropped words to 0
            if config.dec_keep_prob < 1.0:
                keep_mask = tf.less_equal(
                    tf.random_uniform((batch_size, max_out_len - 1), minval=0.0, maxval=1.0),
                    config.dec_keep_prob)
                keep_mask = tf.expand_dims(tf.to_float(keep_mask), 2)
                dec_input_embedding = dec_input_embedding * keep_mask
                dec_input_embedding = tf.reshape(dec_input_embedding, [-1, max_out_len - 1, config.embed_size])

        dec_outs, _, final_context_state = dynamic_rnn_decoder(
            dec_cell, loop_func, inputs=dec_input_embedding, sequence_length=dec_seq_lens)

        if final_context_state is not None:
            final_context_state = final_context_state[:, 0:array_ops.shape(dec_outs)[1]]
            mask = tf.to_int32(tf.sign(tf.reduce_max(dec_outs, axis=2)))
            self.dec_out_words = tf.multiply(tf.reverse(final_context_state, axis=[1]), mask)
        else:
            self.dec_out_words = tf.argmax(dec_outs, 2)

    if not forward:
        with variable_scope.variable_scope("loss"):
            labels = self.output_tokens[:, 1:]
            label_mask = tf.to_float(tf.sign(labels))

            rc_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=dec_outs, labels=labels)
            rc_loss = tf.reduce_sum(rc_loss * label_mask, reduction_indices=1)
            self.avg_rc_loss = tf.reduce_mean(rc_loss)
            self.rc_ppl = tf.exp(tf.reduce_sum(rc_loss) / tf.reduce_sum(label_mask))

            tile_bow_logits = tf.tile(tf.expand_dims(self.bow_logits, 1), [1, max_out_len - 1, 1])
            bow_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=tile_bow_logits, labels=labels) * label_mask
            bow_loss = tf.reduce_sum(bow_loss, reduction_indices=1)
            self.avg_bow_loss = tf.reduce_mean(bow_loss)

            da_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=self.da_logits, labels=self.output_emotions)
            self.avg_da_loss = tf.reduce_mean(da_loss)

            kld = gaussian_kld(recog_mu, recog_logvar, prior_mu, prior_logvar)
            self.avg_kld = tf.reduce_mean(kld)
            if log_dir is not None:
                kl_weights = tf.minimum(tf.to_float(self.global_t) / config.full_kl_step, 1.0)
            else:
                kl_weights = tf.constant(1.0)

            self.kl_w = kl_weights
            self.elbo = self.avg_rc_loss + kl_weights * self.avg_kld
            aug_elbo = self.avg_bow_loss + self.avg_da_loss + self.elbo

            tf.summary.scalar("da_loss", self.avg_da_loss)
            tf.summary.scalar("rc_loss", self.avg_rc_loss)
            tf.summary.scalar("elbo", self.elbo)
            tf.summary.scalar("kld", self.avg_kld)
            tf.summary.scalar("bow_loss", self.avg_bow_loss)

            self.summary_op = tf.summary.merge_all()

            self.log_p_z = norm_log_liklihood(latent_sample, prior_mu, prior_logvar)
            self.log_q_z_xy = norm_log_liklihood(latent_sample, recog_mu, recog_logvar)
            self.est_marginal = tf.reduce_mean(rc_loss + bow_loss - self.log_p_z + self.log_q_z_xy)

        self.optimize(sess, config, aug_elbo, log_dir)

    self.saver = tf.train.Saver(tf.global_variables(), write_version=tf.train.SaverDef.V2)
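
# get_bi_rnn_encode (and get_rnn_encode) are sentence encoders defined elsewhere. A minimal
# sketch of the two-value variant used by the models above follows; note that the last model
# in this file unpacks a three-value variant that apparently also returns the per-step
# outputs, so the project may ship more than one version of this helper.
def get_bi_rnn_encode(embedding, f_cell, b_cell, length_mask=None, scope=None, reuse=None):
    """Encode (batch, time, embed) token embeddings with a bi-RNN; return the
    concatenated final fwd/bwd states and their combined size."""
    with variable_scope.variable_scope(scope, "bi_rnn_encode", reuse=reuse):
        if length_mask is None:
            # infer lengths from non-zero embedding rows (PAD rows are all zeros)
            length_mask = tf.to_int32(tf.reduce_sum(
                tf.sign(tf.reduce_max(tf.abs(embedding), reduction_indices=2)),
                reduction_indices=1))
        _, final_states = tf.nn.bidirectional_dynamic_rnn(
            f_cell, b_cell, embedding, sequence_length=length_mask, dtype=tf.float32)
        encoded = tf.concat(final_states, 1)
        return encoded, f_cell.output_size + b_cell.output_size
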
def __init__(self, sess, config, api, log_dir, forward, scope=None):
    self.vocab = api.vocab
    self.rev_vocab = api.rev_vocab
    self.vocab_size = len(self.vocab)
    self.sess = sess
    self.scope = scope
    self.max_utt_len = config.max_utt_len
    self.go_id = self.rev_vocab["<s>"]
    self.eos_id = self.rev_vocab["</s>"]
    self.context_cell_size = config.cxt_cell_size
    self.sent_cell_size = config.sent_cell_size
    self.dec_cell_size = config.dec_cell_size
    self.num_topics = config.num_topics

    with tf.name_scope("io"):
        # all dialog context and known attributes
        self.input_contexts = tf.placeholder(dtype=tf.int32, shape=(None, None, self.max_utt_len), name="dialog_context")
        self.floors = tf.placeholder(dtype=tf.float32, shape=(None, None), name="floor")  # TODO float
        self.floor_labels = tf.placeholder(dtype=tf.float32, shape=(None, 1), name="floor_labels")
        self.context_lens = tf.placeholder(dtype=tf.int32, shape=(None,), name="context_lens")
        self.paragraph_topics = tf.placeholder(dtype=tf.float32, shape=(None, self.num_topics), name="paragraph_topics")

        # target response given the dialog context
        self.output_tokens = tf.placeholder(dtype=tf.int32, shape=(None, None), name="output_token")
        self.output_lens = tf.placeholder(dtype=tf.int32, shape=(None,), name="output_lens")
        self.output_das = tf.placeholder(dtype=tf.float32, shape=(None, self.num_topics), name="output_dialog_acts")

        # optimization related variables
        self.learning_rate = tf.Variable(float(config.init_lr), trainable=False, name="learning_rate")
        self.learning_rate_decay_op = self.learning_rate.assign(tf.multiply(self.learning_rate, config.lr_decay))
        self.global_t = tf.placeholder(dtype=tf.int32, name="global_t")
        self.use_prior = tf.placeholder(dtype=tf.bool, name="use_prior")

    max_dialog_len = array_ops.shape(self.input_contexts)[1]
    max_out_len = array_ops.shape(self.output_tokens)[1]
    batch_size = array_ops.shape(self.input_contexts)[0]

    with variable_scope.variable_scope("wordEmbedding"):
        self.embedding = tf.get_variable("embedding", [self.vocab_size, config.embed_size], dtype=tf.float32)
        embedding_mask = tf.constant([0 if i == 0 else 1 for i in range(self.vocab_size)],
                                     dtype=tf.float32, shape=[self.vocab_size, 1])
        embedding = self.embedding * embedding_mask

        # embed the input
        input_embedding = embedding_ops.embedding_lookup(embedding, tf.reshape(self.input_contexts, [-1]))
        # reshape the embedding: -1 lets the first dimension be whatever is needed
        # to make the other two dimensions fit the data
        input_embedding = tf.reshape(input_embedding, [-1, self.max_utt_len, config.embed_size])
        # embed the output so it can be fed into the VAE
        output_embedding = embedding_ops.embedding_lookup(embedding, self.output_tokens)

        if config.sent_type == "bow":
            input_embedding, sent_size = get_bow(input_embedding)
            output_embedding, _ = get_bow(output_embedding)
        elif config.sent_type == "rnn":
            sent_cell = self.get_rnncell("gru", self.sent_cell_size, config.keep_prob, 1)
            input_embedding, sent_size = get_rnn_encode(input_embedding, sent_cell, scope="sent_rnn")
            output_embedding, _ = get_rnn_encode(output_embedding, sent_cell, self.output_lens,
                                                 scope="sent_rnn", reuse=True)
        elif config.sent_type == "bi_rnn":
            fwd_sent_cell = self.get_rnncell("gru", self.sent_cell_size, keep_prob=1.0, num_layer=1)
            bwd_sent_cell = self.get_rnncell("gru", self.sent_cell_size, keep_prob=1.0, num_layer=1)
            input_embedding, sent_size = get_bi_rnn_encode(input_embedding, fwd_sent_cell, bwd_sent_cell,
                                                           scope="sent_bi_rnn")
            output_embedding, _ = get_bi_rnn_encode(output_embedding, fwd_sent_cell, bwd_sent_cell,
                                                    self.output_lens, scope="sent_bi_rnn", reuse=True)
        else:
            raise ValueError("Unknown sent_type. Must be one of [bow, rnn, bi_rnn]")

        # reshape input into dialogs
        input_embedding = tf.reshape(input_embedding, [-1, max_dialog_len, sent_size])
        if config.keep_prob < 1.0:
            input_embedding = tf.nn.dropout(input_embedding, config.keep_prob)

        # reshape floors
        floor = tf.reshape(self.floors, [-1, max_dialog_len, 1])
        joint_embedding = tf.concat([input_embedding, floor], 2, "joint_embedding")

    with variable_scope.variable_scope("contextRNN"):
        enc_cell = self.get_rnncell(config.cell_type, self.context_cell_size,
                                    keep_prob=1.0, num_layer=config.num_layer)
        # and enc_last_state will be the same as the true last state
        _, enc_last_state = tf.nn.dynamic_rnn(enc_cell, joint_embedding, dtype=tf.float32,
                                              sequence_length=self.context_lens)

        if config.num_layer > 1:
            if config.cell_type == 'lstm':
                enc_last_state = [temp.h for temp in enc_last_state]
            enc_last_state = tf.concat(enc_last_state, 1)
        else:
            if config.cell_type == 'lstm':
                enc_last_state = enc_last_state.h

    # combine with other attributes
    if config.use_hcf:
        # TODO is this reshape ok?
        attribute_embedding = tf.reshape(self.output_das, [-1, self.num_topics])  # da_embedding
        attribute_fc1 = layers.fully_connected(attribute_embedding, 30, activation_fn=tf.tanh, scope="attribute_fc1")

    # conditions include the paragraph topic and the context RNN state over all
    # previous bi-RNN sentence encodings
    cond_list = [self.paragraph_topics, enc_last_state]
    cond_embedding = tf.concat(cond_list, 1)  # float32

    with variable_scope.variable_scope("recognitionNetwork"):
        if config.use_hcf:
            recog_input = tf.concat([cond_embedding, output_embedding, attribute_fc1], 1)
        else:
            recog_input = tf.concat([cond_embedding, output_embedding], 1)
        self.recog_mulogvar = recog_mulogvar = layers.fully_connected(
            recog_input, config.latent_size * 2, activation_fn=None, scope="muvar")
        # mu and logvar are both vectors of size latent_size
        recog_mu, recog_logvar = tf.split(recog_mulogvar, 2, axis=1)

    with variable_scope.variable_scope("priorNetwork"):
        # P(XYZ) = P(Z|X)P(X)P(Y|X,Z)
        prior_fc1 = layers.fully_connected(cond_embedding, np.maximum(config.latent_size * 2, 100),
                                           activation_fn=tf.tanh, scope="fc1")
        prior_mulogvar = layers.fully_connected(prior_fc1, config.latent_size * 2,
                                                activation_fn=None, scope="muvar")
        prior_mu, prior_logvar = tf.split(prior_mulogvar, 2, axis=1)

        latent_sample = tf.cond(self.use_prior,
                                lambda: sample_gaussian(prior_mu, prior_logvar),
                                lambda: sample_gaussian(recog_mu, recog_logvar))

    with variable_scope.variable_scope("generationNetwork"):
        gen_inputs = tf.concat([cond_embedding, latent_sample], 1)  # float32

        # BOW loss
        bow_fc1 = layers.fully_connected(gen_inputs, 400, activation_fn=tf.tanh, scope="bow_fc1")
        if config.keep_prob < 1.0:
            bow_fc1 = tf.nn.dropout(bow_fc1, config.keep_prob)
        self.bow_logits = layers.fully_connected(bow_fc1, self.vocab_size,
                                                 activation_fn=None, scope="bow_project")

        # Predicting Y (topic)
        if config.use_hcf:
            meta_fc1 = layers.fully_connected(gen_inputs, 400, activation_fn=tf.tanh, scope="meta_fc1")
            if config.keep_prob < 1.0:
                meta_fc1 = tf.nn.dropout(meta_fc1, config.keep_prob)
            self.da_logits = layers.fully_connected(meta_fc1, self.num_topics, scope="da_project")  # float32
            da_prob = tf.nn.softmax(self.da_logits)
            pred_attribute_embedding = da_prob  # TODO rename to predicted sentence topic
            # pred_attribute_embedding = tf.matmul(da_prob, d_embedding)
            if forward:
                selected_attribute_embedding = pred_attribute_embedding
            else:
                selected_attribute_embedding = attribute_embedding
            dec_inputs = tf.concat([gen_inputs, selected_attribute_embedding], 1)
        else:
            # if use_hcf is off, the model won't predict Y
            self.da_logits = tf.zeros((batch_size, self.num_topics))
            dec_inputs = gen_inputs
            selected_attribute_embedding = None

        # Predicting whether or not this is the end of the paragraph
        self.paragraph_end_logits = layers.fully_connected(gen_inputs, 1, activation_fn=tf.tanh,
                                                           scope="paragraph_end_fc1")  # float32

        # Decoder
        if config.num_layer > 1:
            dec_init_state = []
            for i in range(config.num_layer):
                temp_init = layers.fully_connected(dec_inputs, self.dec_cell_size,
                                                   activation_fn=None, scope="init_state-%d" % i)
                if config.cell_type == 'lstm':  # initializer for lstm
                    temp_init = rnn_cell.LSTMStateTuple(temp_init, temp_init)
                dec_init_state.append(temp_init)
            dec_init_state = tuple(dec_init_state)
        else:
            dec_init_state = layers.fully_connected(dec_inputs, self.dec_cell_size,
                                                    activation_fn=None, scope="init_state")
            if config.cell_type == 'lstm':
                dec_init_state = rnn_cell.LSTMStateTuple(dec_init_state, dec_init_state)

    with variable_scope.variable_scope("decoder"):
        dec_cell = self.get_rnncell(config.cell_type, self.dec_cell_size, config.keep_prob, config.num_layer)
        # projects decoder outputs to vocab-size logits (softmax is applied later, in the loss)
        dec_cell = OutputProjectionWrapper(dec_cell, self.vocab_size)

        if forward:
            loop_func = decoder_fn_lib.context_decoder_fn_inference(
                None, dec_init_state, embedding,
                start_of_sequence_id=self.go_id,
                end_of_sequence_id=self.eos_id,
                maximum_length=self.max_utt_len,
                num_decoder_symbols=self.vocab_size,
                context_vector=selected_attribute_embedding)
            dec_input_embedding = None
            dec_seq_lens = None
        else:
            loop_func = decoder_fn_lib.context_decoder_fn_train(dec_init_state, selected_attribute_embedding)
            dec_input_embedding = embedding_ops.embedding_lookup(embedding, self.output_tokens)
            dec_input_embedding = dec_input_embedding[:, 0:-1, :]
            dec_seq_lens = self.output_lens - 1

            if config.keep_prob < 1.0:
                dec_input_embedding = tf.nn.dropout(dec_input_embedding, config.keep_prob)

            # apply word dropping: set dropped words to 0
            if config.dec_keep_prob < 1.0:
                # build the keep/drop mask
                keep_mask = tf.less_equal(
                    tf.random_uniform((batch_size, max_out_len - 1), minval=0.0, maxval=1.0),
                    config.dec_keep_prob)
                keep_mask = tf.expand_dims(tf.to_float(keep_mask), 2)
                dec_input_embedding = dec_input_embedding * keep_mask
                dec_input_embedding = tf.reshape(dec_input_embedding, [-1, max_out_len - 1, config.embed_size])

        dec_outs, _, final_context_state = dynamic_rnn_decoder(
            dec_cell, loop_func, inputs=dec_input_embedding,
            sequence_length=dec_seq_lens, name='output_node')

        if final_context_state is not None:
            final_context_state = final_context_state[:, 0:array_ops.shape(dec_outs)[1]]
            mask = tf.to_int32(tf.sign(tf.reduce_max(dec_outs, axis=2)))
            self.dec_out_words = tf.multiply(tf.reverse(final_context_state, axis=[1]), mask)
        else:
            self.dec_out_words = tf.argmax(dec_outs, 2)

    if not forward:
        with variable_scope.variable_scope("loss"):
            labels = self.output_tokens[:, 1:]  # correct word tokens
            label_mask = tf.to_float(tf.sign(labels))

            # loss between words
            rc_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=dec_outs, labels=labels)
            rc_loss = tf.reduce_sum(rc_loss * label_mask, reduction_indices=1)
            self.avg_rc_loss = tf.reduce_mean(rc_loss)
            # used only for perplexity calculation; not used for optimization
            self.rc_ppl = tf.exp(tf.reduce_sum(rc_loss) / tf.reduce_sum(label_mask))

            # BOW loss
            tile_bow_logits = tf.tile(tf.expand_dims(self.bow_logits, 1), [1, max_out_len - 1, 1])
            bow_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=tile_bow_logits, labels=labels) * label_mask
            bow_loss = tf.reduce_sum(bow_loss, reduction_indices=1)
            self.avg_bow_loss = tf.reduce_mean(bow_loss)

            # Predict 0/1 (1 = last sentence in the paragraph). A softmax over a single
            # logit is degenerate (constant zero loss), so sigmoid cross-entropy is used.
            end_loss = tf.nn.sigmoid_cross_entropy_with_logits(
                labels=self.floor_labels, logits=self.paragraph_end_logits)
            self.avg_end_loss = tf.reduce_mean(end_loss)

            # topic prediction loss
            if config.use_hcf:
                div_prob = tf.divide(self.da_logits, self.output_das)
                self.avg_da_loss = tf.reduce_mean(
                    -tf.nn.softmax_cross_entropy_with_logits(logits=self.da_logits, labels=div_prob))
            else:
                self.avg_da_loss = 0.0

            kld = gaussian_kld(recog_mu, recog_logvar, prior_mu, prior_logvar)
            self.avg_kld = tf.reduce_mean(kld)
            if log_dir is not None:
                kl_weights = tf.minimum(tf.to_float(self.global_t) / config.full_kl_step, 1.0)
            else:
                kl_weights = tf.constant(1.0)

            self.kl_w = kl_weights
            self.elbo = self.avg_rc_loss + kl_weights * self.avg_kld
            aug_elbo = self.avg_bow_loss + self.avg_da_loss + self.elbo + self.avg_end_loss

            tf.summary.scalar("da_loss", self.avg_da_loss)
            tf.summary.scalar("rc_loss", self.avg_rc_loss)
            tf.summary.scalar("elbo", self.elbo)
            tf.summary.scalar("kld", self.avg_kld)
            tf.summary.scalar("bow_loss", self.avg_bow_loss)
            tf.summary.scalar("paragraph_end_loss", self.avg_end_loss)

            self.summary_op = tf.summary.merge_all()

            self.log_p_z = norm_log_liklihood(latent_sample, prior_mu, prior_logvar)
            self.log_q_z_xy = norm_log_liklihood(latent_sample, recog_mu, recog_logvar)
            self.est_marginal = tf.reduce_mean(rc_loss + bow_loss - self.log_p_z + self.log_q_z_xy)

        self.optimize(sess, config, aug_elbo, log_dir)

    self.saver = tf.train.Saver(tf.global_variables(), write_version=tf.train.SaverDef.V2)
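
# get_bow is the bag-of-words sentence encoder selected by sent_type == "bow". A minimal
# sketch, assuming it simply pools the word embeddings over time and also returns the
# embedding size (matching the two-value unpacking above); the project's version may
# average instead of sum.
def get_bow(embedding, avg=False):
    """Pool (batch, time, embed) word embeddings into (batch, embed)."""
    embedding_size = embedding.get_shape()[2].value
    if avg:
        return tf.reduce_mean(embedding, reduction_indices=[1]), embedding_size
    return tf.reduce_sum(embedding, reduction_indices=[1]), embedding_size
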
def __init__(self, sess, config, api, log_dir, forward, scope=None, name=None): self.vocab = api.vocab self.rev_vocab = api.rev_vocab self.vocab_size = len(self.vocab) self.idf = api.index2idf self.gen_vocab_size = api.gen_vocab_size self.topic_vocab = api.topic_vocab self.topic_vocab_size = len(self.topic_vocab) self.da_vocab = api.dialog_act_vocab self.da_vocab_size = len(self.da_vocab) self.sess = sess self.scope = scope self.max_utt_len = config.max_utt_len self.max_per_len = config.max_per_len self.max_per_line = config.max_per_line self.max_per_words = config.max_per_words self.go_id = self.rev_vocab["<s>"] self.eos_id = self.rev_vocab["</s>"] self.context_cell_size = config.cxt_cell_size self.sent_cell_size = config.sent_cell_size self.memory_cell_size = config.memory_cell_size self.dec_cell_size = config.dec_cell_size self.hops = config.hops self.batch_size = config.batch_size self.test_samples = config.test_samples self.balance_factor = config.balance_factor with tf.name_scope("io"): self.first_dimension_size = self.batch_size self.input_contexts = tf.placeholder( dtype=tf.int32, shape=(self.first_dimension_size, None, self.max_utt_len), name="dialog_context") self.floors = tf.placeholder(dtype=tf.int32, shape=(self.first_dimension_size, None), name="floor") self.context_lens = tf.placeholder( dtype=tf.int32, shape=(self.first_dimension_size, ), name="context_lens") self.topics = tf.placeholder(dtype=tf.int32, shape=(self.first_dimension_size, ), name="topics") self.personas = tf.placeholder(dtype=tf.int32, shape=(self.first_dimension_size, self.max_per_line, self.max_per_len), name="personas") self.persona_words = tf.placeholder( dtype=tf.int32, shape=(self.first_dimension_size, self.max_per_line, self.max_per_len), name="persona_words") self.persona_position = tf.placeholder( dtype=tf.int32, shape=(self.first_dimension_size, None), name="persona_position") self.selected_persona = tf.placeholder( dtype=tf.int32, shape=(self.first_dimension_size, 1), name="selected_persona") self.query = tf.placeholder(dtype=tf.int32, shape=(self.first_dimension_size, self.max_utt_len), name="query") # target response given the dialog context self.output_tokens = tf.placeholder( dtype=tf.int32, shape=(self.first_dimension_size, None), name="output_token") self.output_lens = tf.placeholder( dtype=tf.int32, shape=(self.first_dimension_size, ), name="output_lens") # optimization related variables self.learning_rate = tf.Variable(float(config.init_lr), trainable=False, name="learning_rate") self.learning_rate_decay_op = self.learning_rate.assign( tf.multiply(self.learning_rate, config.lr_decay)) self.global_t = tf.placeholder(dtype=tf.int32, name="global_t") self.use_prior = tf.placeholder(dtype=tf.bool, name="use_prior") max_context_lines = array_ops.shape(self.input_contexts)[1] max_out_len = array_ops.shape(self.output_tokens)[1] batch_size = array_ops.shape(self.input_contexts)[0] with variable_scope.variable_scope("wordEmbedding"): self.embedding = tf.get_variable( "embedding", [self.vocab_size, config.embed_size], dtype=tf.float32) embedding_mask = tf.constant( [0 if i == 0 else 1 for i in range(self.vocab_size)], dtype=tf.float32, shape=[self.vocab_size, 1]) embedding = self.embedding * embedding_mask input_embedding = embedding_ops.embedding_lookup( embedding, tf.reshape(self.input_contexts, [-1])) input_embedding = tf.reshape( input_embedding, [-1, self.max_utt_len, config.embed_size]) output_embedding = embedding_ops.embedding_lookup( embedding, self.output_tokens) persona_input_embedding = 
embedding_ops.embedding_lookup( embedding, tf.reshape(self.personas, [-1])) persona_input_embedding = tf.reshape( persona_input_embedding, [-1, self.max_per_len, config.embed_size]) if config.sent_type == "bow": input_embedding, sent_size = get_bow(input_embedding) output_embedding, _ = get_bow(output_embedding) persona_input_embedding, _ = get_bow(persona_input_embedding) elif config.sent_type == "rnn": sent_cell = self.get_rnncell("gru", self.sent_cell_size, config.keep_prob, 1) _, input_embedding, sent_size = get_rnn_encode( input_embedding, sent_cell, scope="sent_rnn") _, output_embedding, _ = get_rnn_encode(output_embedding, sent_cell, self.output_lens, scope="sent_rnn", reuse=True) _, persona_input_embedding, _ = get_rnn_encode( persona_input_embedding, sent_cell, scope="sent_rnn", reuse=True) elif config.sent_type == "bi_rnn": fwd_sent_cell = self.get_rnncell("gru", self.sent_cell_size, keep_prob=1.0, num_layer=1) bwd_sent_cell = self.get_rnncell("gru", self.sent_cell_size, keep_prob=1.0, num_layer=1) input_step_embedding, input_embedding, sent_size = get_bi_rnn_encode( input_embedding, fwd_sent_cell, bwd_sent_cell, scope="sent_bi_rnn") _, output_embedding, _ = get_bi_rnn_encode(output_embedding, fwd_sent_cell, bwd_sent_cell, self.output_lens, scope="sent_bi_rnn", reuse=True) _, persona_input_embedding, _ = get_bi_rnn_encode( persona_input_embedding, fwd_sent_cell, bwd_sent_cell, scope="sent_bi_rnn", reuse=True) else: raise ValueError( "Unknown sent_type. Must be one of [bow, rnn, bi_rnn]") # reshape input into dialogs input_embedding = tf.reshape(input_embedding, [-1, max_context_lines, sent_size]) self.input_step_embedding = input_step_embedding self.encoder_state_size = sent_size if config.keep_prob < 1.0: input_embedding = tf.nn.dropout(input_embedding, config.keep_prob) with variable_scope.variable_scope("personaMemory"): embedding_mask = tf.constant( [0 if i == 0 else 1 for i in range(self.vocab_size)], dtype=tf.float32, shape=[self.vocab_size, 1]) A = tf.get_variable("persona_embedding_A", [self.vocab_size, self.memory_cell_size], dtype=tf.float32) A = A * embedding_mask C = [] for hopn in range(self.hops): C.append( tf.get_variable("persona_embedding_C_hop_{}".format(hopn), [self.vocab_size, self.memory_cell_size], dtype=tf.float32) * embedding_mask) q_emb = tf.nn.embedding_lookup(A, self.query) u_0 = tf.reduce_sum(q_emb, 1) u = [u_0] for hopn in range(self.hops): if hopn == 0: m_emb_A = tf.nn.embedding_lookup(A, self.personas) m_A = tf.reshape(m_emb_A, [ -1, self.max_per_len * self.max_per_line, self.memory_cell_size ]) else: with tf.variable_scope('persona_hop_{}'.format(hopn)): m_emb_A = tf.nn.embedding_lookup( C[hopn - 1], self.personas) m_A = tf.reshape(m_emb_A, [ -1, self.max_per_len * self.max_per_line, self.memory_cell_size ]) u_temp = tf.transpose(tf.expand_dims(u[-1], -1), [0, 2, 1]) dotted = tf.reduce_sum(m_A * u_temp, 2) probs = tf.nn.softmax(dotted) probs_temp = tf.transpose(tf.expand_dims(probs, -1), [0, 2, 1]) with tf.variable_scope('persona_hop_{}'.format(hopn)): m_emb_C = tf.nn.embedding_lookup( C[hopn], tf.reshape(self.personas, [-1, self.max_per_len * self.max_per_line])) m_emb_C = tf.expand_dims(m_emb_C, -2) m_C = tf.reduce_sum(m_emb_C, axis=2) c_temp = tf.transpose(m_C, [0, 2, 1]) o_k = tf.reduce_sum(c_temp * probs_temp, axis=2) u_k = u[-1] + o_k u.append(u_k) persona_memory = u[-1] with variable_scope.variable_scope("contextEmbedding"): context_layers = 2 enc_cell = self.get_rnncell(config.cell_type, self.context_cell_size, keep_prob=1.0, 
    with variable_scope.variable_scope("contextEmbedding"):
        context_layers = 2
        enc_cell = self.get_rnncell(config.cell_type,
                                    self.context_cell_size,
                                    keep_prob=1.0,
                                    num_layer=context_layers)
        _, enc_last_state = tf.nn.dynamic_rnn(
            enc_cell,
            input_embedding,
            dtype=tf.float32,
            sequence_length=self.context_lens)
        if context_layers > 1:
            if config.cell_type == 'lstm':
                enc_last_state = [temp.h for temp in enc_last_state]
            enc_last_state = tf.concat(enc_last_state, 1)
        else:
            if config.cell_type == 'lstm':
                enc_last_state = enc_last_state.h

    cond_embedding = tf.concat([persona_memory, enc_last_state], 1)

    with variable_scope.variable_scope("recognitionNetwork"):
        recog_input = tf.concat(
            [cond_embedding, output_embedding, persona_memory], 1)
        self.recog_mulogvar = recog_mulogvar = layers.fully_connected(
            recog_input,
            config.latent_size * 2,
            activation_fn=None,
            scope="muvar")
        recog_mu, recog_logvar = tf.split(recog_mulogvar, 2, axis=1)

    with variable_scope.variable_scope("priorNetwork"):
        prior_fc1 = layers.fully_connected(cond_embedding,
                                           np.maximum(config.latent_size * 2,
                                                      100),
                                           activation_fn=tf.tanh,
                                           scope="fc1")
        prior_mulogvar = layers.fully_connected(prior_fc1,
                                                config.latent_size * 2,
                                                activation_fn=None,
                                                scope="muvar")
        prior_mu, prior_logvar = tf.split(prior_mulogvar, 2, axis=1)

    latent_sample = tf.cond(
        self.use_prior,
        lambda: sample_gaussian(prior_mu, prior_logvar),
        lambda: sample_gaussian(recog_mu, recog_logvar))

    with variable_scope.variable_scope("personaSelecting"):
        condition = tf.concat([persona_memory, latent_sample], 1)
        self.persona_dist = tf.nn.log_softmax(
            layers.fully_connected(condition,
                                   self.max_per_line,
                                   activation_fn=tf.tanh,
                                   scope="persona_dist"))
        select_temp = tf.expand_dims(
            tf.argmax(self.persona_dist, 1, output_type=tf.int32), 1)
        index_temp = tf.expand_dims(
            tf.range(0, self.first_dimension_size, dtype=tf.int32), 1)
        persona_select = tf.concat([index_temp, select_temp], 1)
        selected_words_ordered = tf.reshape(
            tf.gather_nd(self.persona_words, persona_select),
            [self.max_per_len * self.first_dimension_size])
        self.selected_words = tf.gather_nd(self.persona_words,
                                           persona_select)
        label = tf.reshape(
            selected_words_ordered,
            [self.max_per_len * self.first_dimension_size, 1])
        index = tf.reshape(
            tf.range(self.first_dimension_size, dtype=tf.int32),
            [self.first_dimension_size, 1])
        index = tf.reshape(
            tf.tile(index, [1, self.max_per_len]),
            [self.max_per_len * self.first_dimension_size, 1])
        concated = tf.concat([index, label], 1)
        true_labels = tf.where(selected_words_ordered > 0)
        concated = tf.gather_nd(concated, true_labels)
        self.persona_word_mask = tf.sparse_to_dense(
            concated, [self.first_dimension_size, self.vocab_size],
            config.perw_weight, 0.0)
        self.other_word_mask = tf.sparse_to_dense(
            concated, [self.first_dimension_size, self.vocab_size],
            0.0, config.othw_weight)
        self.persona_word_mask = self.persona_word_mask * self.idf

    with variable_scope.variable_scope("generationNetwork"):
        gen_inputs = tf.concat([cond_embedding, latent_sample], 1)

        # BOW loss
        bow_fc1 = layers.fully_connected(gen_inputs,
                                         400,
                                         activation_fn=tf.tanh,
                                         scope="bow_fc1")
        if config.keep_prob < 1.0:
            bow_fc1 = tf.nn.dropout(bow_fc1, config.keep_prob)
        self.bow_logits = layers.fully_connected(bow_fc1,
                                                 self.vocab_size,
                                                 activation_fn=None,
                                                 scope="bow_project")

        # Y loss
        dec_inputs = gen_inputs
        selected_attribute_embedding = None
        self.da_logits = tf.zeros((batch_size, self.da_vocab_size))
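        # The decoder's initial state built below is a learned linear map of
        # gen_inputs = [cond_embedding; z], one projection per layer, i.e.
        # roughly h0 = W [c; z] + b (no activation), duplicated into an
        # LSTMStateTuple when the cell is an LSTM.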
        # Decoder
        if config.num_layer > 1:
            dec_init_state = []
            for i in range(config.num_layer):
                temp_init = layers.fully_connected(dec_inputs,
                                                   self.dec_cell_size,
                                                   activation_fn=None,
                                                   scope="init_state-%d" % i)
                if config.cell_type == 'lstm':
                    temp_init = rnn_cell.LSTMStateTuple(temp_init, temp_init)
                dec_init_state.append(temp_init)
            dec_init_state = tuple(dec_init_state)
        else:
            dec_init_state = layers.fully_connected(dec_inputs,
                                                    self.dec_cell_size,
                                                    activation_fn=None,
                                                    scope="init_state")
            if config.cell_type == 'lstm':
                dec_init_state = rnn_cell.LSTMStateTuple(dec_init_state,
                                                         dec_init_state)

    with variable_scope.variable_scope("decoder"):
        dec_cell = self.get_rnncell(config.cell_type, self.dec_cell_size,
                                    config.keep_prob, config.num_layer)
        dec_cell = OutputProjectionWrapper(dec_cell, self.vocab_size)
        pos_cell = self.get_rnncell(config.cell_type, self.dec_cell_size,
                                    config.keep_prob, config.num_layer)
        pos_cell = OutputProjectionWrapper(pos_cell, self.vocab_size)

        with variable_scope.variable_scope("position"):
            self.pos_w_1 = tf.get_variable("pos_w_1",
                                           [self.dec_cell_size, 2],
                                           dtype=tf.float32)
            self.pos_b_1 = tf.get_variable("pos_b_1", [2], dtype=tf.float32)

        def position_function(states, logp=False):
            states = tf.reshape(states, [-1, self.dec_cell_size])
            if logp:
                return tf.reshape(
                    tf.nn.log_softmax(
                        tf.matmul(states, self.pos_w_1) + self.pos_b_1),
                    [self.first_dimension_size, -1, 2])
            return tf.reshape(
                tf.nn.softmax(
                    tf.matmul(states, self.pos_w_1) + self.pos_b_1),
                [self.first_dimension_size, -1, 2])

        if forward:
            loop_func = self.context_decoder_fn_inference(
                position_function,
                self.persona_word_mask,
                self.other_word_mask,
                None,
                dec_init_state,
                embedding,
                start_of_sequence_id=self.go_id,
                end_of_sequence_id=self.eos_id,
                maximum_length=self.max_utt_len,
                num_decoder_symbols=self.vocab_size,
                context_vector=selected_attribute_embedding,
            )
            dec_input_embedding = None
            dec_seq_lens = None
        else:
            loop_func = self.context_decoder_fn_train(
                dec_init_state, selected_attribute_embedding)
            dec_input_embedding = embedding_ops.embedding_lookup(
                embedding, self.output_tokens)
            dec_input_embedding = dec_input_embedding[:, 0:-1, :]
            dec_seq_lens = self.output_lens - 1
            if config.keep_prob < 1.0:
                dec_input_embedding = tf.nn.dropout(dec_input_embedding,
                                                    config.keep_prob)
            if config.dec_keep_prob < 1.0:
                keep_mask = tf.less_equal(
                    tf.random_uniform((batch_size, max_out_len - 1),
                                      minval=0.0,
                                      maxval=1.0), config.dec_keep_prob)
                keep_mask = tf.expand_dims(tf.to_float(keep_mask), 2)
                dec_input_embedding = dec_input_embedding * keep_mask
                dec_input_embedding = tf.reshape(
                    dec_input_embedding,
                    [-1, max_out_len - 1, config.embed_size])

        with variable_scope.variable_scope("dec_state"):
            dec_outs, _, final_context_state, rnn_states = dynamic_rnn_decoder(
                dec_cell,
                loop_func,
                inputs=dec_input_embedding,
                sequence_length=dec_seq_lens)
        with variable_scope.variable_scope("pos_state"):
            _, _, _, pos_states = dynamic_rnn_decoder(
                pos_cell,
                loop_func,
                inputs=dec_input_embedding,
                sequence_length=dec_seq_lens)
        self.position_dist = position_function(pos_states, logp=True)
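        # Word dropout (dec_keep_prob) above zeroes whole input-word embeddings
        # during teacher forcing; e.g. with dec_keep_prob = 0.8, roughly 20% of
        # input positions are blanked, which pushes the decoder to rely on the
        # latent z instead of simply copying the previous gold token.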
        if final_context_state is not None:
            final_context_state = final_context_state[:, 0:array_ops.
                                                      shape(dec_outs)[1]]
            mask = tf.to_int32(tf.sign(tf.reduce_max(dec_outs, axis=2)))
            self.dec_out_words = tf.multiply(
                tf.reverse(final_context_state, axis=[1]), mask)
        else:
            self.dec_out_words = tf.argmax(dec_outs, 2)

    if not forward:
        with variable_scope.variable_scope("loss"):
            labels = self.output_tokens[:, 1:]
            label_mask = tf.to_float(tf.sign(labels))

            rc_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=dec_outs, labels=labels)
            rc_loss = tf.reduce_sum(rc_loss * label_mask,
                                    reduction_indices=1)
            self.avg_rc_loss = tf.reduce_mean(rc_loss)
            self.rc_ppl = tf.exp(
                tf.reduce_sum(rc_loss) / tf.reduce_sum(label_mask))

            per_select_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=tf.reshape(self.persona_dist,
                                  [self.first_dimension_size, 1, -1]),
                labels=self.selected_persona)
            per_select_loss = tf.reduce_sum(per_select_loss,
                                            reduction_indices=1)
            self.avg_per_select_loss = tf.reduce_mean(per_select_loss)

            position_labels = self.persona_position[:, 1:]
            per_pos_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=self.position_dist, labels=position_labels)
            per_pos_loss = tf.reduce_sum(per_pos_loss, reduction_indices=1)
            self.avg_per_pos_loss = tf.reduce_mean(per_pos_loss)

            tile_bow_logits = tf.tile(tf.expand_dims(self.bow_logits, 1),
                                      [1, max_out_len - 1, 1])
            bow_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=tile_bow_logits, labels=labels) * label_mask
            bow_loss = tf.reduce_sum(bow_loss, reduction_indices=1)
            self.avg_bow_loss = tf.reduce_mean(bow_loss)

            kld = gaussian_kld(recog_mu, recog_logvar, prior_mu,
                               prior_logvar)
            self.avg_kld = tf.reduce_mean(kld)
            if log_dir is not None:
                kl_weights = tf.minimum(
                    tf.to_float(self.global_t) / config.full_kl_step, 1.0)
            else:
                kl_weights = tf.constant(1.0)

            self.kl_w = kl_weights
            self.elbo = self.avg_rc_loss + kl_weights * self.avg_kld
            aug_elbo = (self.elbo + self.avg_bow_loss +
                        0.1 * self.avg_per_select_loss +
                        0.05 * self.avg_per_pos_loss)

            tf.summary.scalar("rc_loss", self.avg_rc_loss)
            tf.summary.scalar("elbo", self.elbo)
            tf.summary.scalar("kld", self.avg_kld)
            tf.summary.scalar("per_pos_loss", self.avg_per_pos_loss)
            self.summary_op = tf.summary.merge_all()

            self.log_p_z = norm_log_liklihood(latent_sample, prior_mu,
                                              prior_logvar)
            self.log_q_z_xy = norm_log_liklihood(latent_sample, recog_mu,
                                                 recog_logvar)
            self.est_marginal = tf.reduce_mean(rc_loss + bow_loss -
                                               self.log_p_z +
                                               self.log_q_z_xy)

        self.optimize(sess, config, aug_elbo, log_dir)

    self.saver = tf.train.Saver(tf.global_variables(),
                                write_version=tf.train.SaverDef.V2)
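# kl_weights above implements linear KL annealing: the weight ramps from 0 to 1
# over the first full_kl_step training steps, then stays at 1. gaussian_kld
# presumably computes the closed-form KL between the diagonal-Gaussian
# recognition and prior distributions; per latent dimension:
#   KL(q || p) = 0.5 * (p_logvar - q_logvar
#                       + (exp(q_logvar) + (q_mu - p_mu)^2) / exp(p_logvar) - 1)
# and est_marginal is an importance-sampling estimate of the negative marginal
# log-likelihood: E_q[rc_loss + bow_loss - log p(z) + log q(z|x,y)].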
def __init__(self, sess, config, api, log_dir, forward, scope=None):
    # self.self_label = tf.placeholder(dtype=tf.bool, shape=(None), name="self_label")
    self.self_label = False
    self.vocab = api.vocab
    self.rev_vocab = api.rev_vocab
    self.vocab_size = len(self.vocab)
    self.sess = sess
    self.scope = scope
    self.max_utt_len = config.max_utt_len
    self.go_id = self.rev_vocab["<s>"]
    self.eos_id = self.rev_vocab["</s>"]
    self.context_cell_size = config.cxt_cell_size
    self.sent_cell_size = config.sent_cell_size
    self.dec_cell_size = config.dec_cell_size

    with tf.name_scope("io"):
        self.input_contexts = tf.placeholder(dtype=tf.int32,
                                             shape=(None, None),
                                             name="dialog_context")
        self.context_lens = tf.placeholder(dtype=tf.int32,
                                           shape=(None, ),
                                           name="context_lens")
        self.output_tokens = tf.placeholder(dtype=tf.int32,
                                            shape=(None, None),
                                            name="output_token")
        self.output_lens = tf.placeholder(dtype=tf.int32,
                                          shape=(None, ),
                                          name="output_lens")

        # optimization related variables
        self.learning_rate = tf.Variable(float(config.init_lr),
                                         trainable=False,
                                         name="learning_rate")
        self.learning_rate_decay_op = self.learning_rate.assign(
            tf.multiply(self.learning_rate, config.lr_decay))
        self.global_t = tf.placeholder(dtype=tf.int32, name="global_t")
        self.use_prior = tf.placeholder(dtype=tf.bool, name="use_prior")

    max_input_len = array_ops.shape(self.input_contexts)[1]
    max_out_len = array_ops.shape(self.output_tokens)[1]
    batch_size = array_ops.shape(self.input_contexts)[0]

    with variable_scope.variable_scope("wordEmbedding"):
        self.embedding = tf.get_variable(
            "embedding", [self.vocab_size, config.embed_size],
            dtype=tf.float32)
        embedding_mask = tf.constant(
            [0 if i == 0 else 1 for i in range(self.vocab_size)],
            dtype=tf.float32,
            shape=[self.vocab_size, 1])
        embedding = self.embedding * embedding_mask
        input_embedding = embedding_ops.embedding_lookup(
            embedding, self.input_contexts)
        output_embedding = embedding_ops.embedding_lookup(
            embedding, self.output_tokens)

        if config.sent_type == "rnn":
            sent_cell = self.get_rnncell("gru", self.sent_cell_size,
                                         config.keep_prob, 1)
            input_embedding, sent_size = get_rnn_encode(input_embedding,
                                                        sent_cell,
                                                        scope="sent_rnn")
            output_embedding, _ = get_rnn_encode(output_embedding,
                                                 sent_cell,
                                                 self.output_lens,
                                                 scope="sent_rnn",
                                                 reuse=True)
        elif config.sent_type == "bi_rnn":
            fwd_sent_cell = self.get_rnncell("gru", self.sent_cell_size,
                                             keep_prob=1.0, num_layer=1)
            bwd_sent_cell = self.get_rnncell("gru", self.sent_cell_size,
                                             keep_prob=1.0, num_layer=1)
            input_embedding, sent_size = get_bi_rnn_encode(
                input_embedding, fwd_sent_cell, bwd_sent_cell,
                self.context_lens, scope="sent_bi_rnn")
            output_embedding, _ = get_bi_rnn_encode(output_embedding,
                                                    fwd_sent_cell,
                                                    bwd_sent_cell,
                                                    self.output_lens,
                                                    scope="sent_bi_rnn",
                                                    reuse=True)
        else:
            raise ValueError(
                "Unknown sent_type. Must be one of [bow, rnn, bi_rnn]")
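        # get_bi_rnn_encode presumably concatenates the final forward and
        # backward GRU states, so sent_size would be 2 * sent_cell_size per
        # utterance (an assumption about the helper, not verified here).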
        # reshape input into dialogs
        if config.keep_prob < 1.0:
            input_embedding = tf.nn.dropout(input_embedding,
                                            config.keep_prob)

    with variable_scope.variable_scope("contextRNN"):
        enc_cell = self.get_rnncell(config.cell_type,
                                    self.context_cell_size,
                                    keep_prob=1.0,
                                    num_layer=config.num_layer)
        # enc_last_state will be the same as the true last state
        input_embedding = tf.expand_dims(input_embedding, axis=2)
        _, enc_last_state = tf.nn.dynamic_rnn(
            enc_cell,
            input_embedding,
            dtype=tf.float32,
            sequence_length=self.context_lens)
        if config.num_layer > 1:
            if config.cell_type == 'lstm':
                enc_last_state = [temp.h for temp in enc_last_state]
            enc_last_state = tf.concat(enc_last_state, 1)
        else:
            if config.cell_type == 'lstm':
                enc_last_state = enc_last_state.h

    # recognition network: input [enc_last_state, output_embedding], i.e. [c, x] ---> z
    with variable_scope.variable_scope("recognitionNetwork"):
        recog_input = tf.concat([enc_last_state, output_embedding], 1)
        self.recog_mulogvar = recog_mulogvar = layers.fully_connected(
            recog_input,
            config.latent_size * 2,
            activation_fn=None,
            scope="muvar")
        recog_mu, recog_logvar = tf.split(recog_mulogvar, 2, axis=1)

    with variable_scope.variable_scope("priorNetwork"):
        # P(XYZ) = P(Z|X)P(X)P(Y|X,Z)
        prior_fc1 = layers.fully_connected(enc_last_state,
                                           np.maximum(config.latent_size * 2,
                                                      100),
                                           activation_fn=tf.tanh,
                                           scope="fc1")
        prior_mulogvar = layers.fully_connected(prior_fc1,
                                                config.latent_size * 2,
                                                activation_fn=None,
                                                scope="muvar")
        prior_mu, prior_logvar = tf.split(prior_mulogvar, 2, axis=1)

        # use sampled Z or posterior Z
        latent_sample = tf.cond(
            self.use_prior,
            lambda: sample_gaussian(prior_mu, prior_logvar),
            lambda: sample_gaussian(recog_mu, recog_logvar))

    with variable_scope.variable_scope("label_encoder"):
        le_embedding_mask = tf.constant(
            [0 if i == 0 else 1 for i in range(self.vocab_size)],
            dtype=tf.float32,
            shape=[self.vocab_size, 1])
        le_embedding = self.embedding * le_embedding_mask
        le_input_embedding = embedding_ops.embedding_lookup(
            le_embedding, self.input_contexts)
        le_output_embedding = embedding_ops.embedding_lookup(
            le_embedding, self.output_tokens)

        if config.sent_type == "rnn":
            le_sent_cell = self.get_rnncell("gru", self.sent_cell_size,
                                            config.keep_prob, 1)
            le_input_embedding, le_sent_size = get_rnn_encode(
                le_input_embedding, le_sent_cell, scope="sent_rnn")
            le_output_embedding, _ = get_rnn_encode(le_output_embedding,
                                                    le_sent_cell,
                                                    self.output_lens,
                                                    scope="sent_rnn",
                                                    reuse=True)
        elif config.sent_type == "bi_rnn":
            le_fwd_sent_cell = self.get_rnncell("gru", self.sent_cell_size,
                                                keep_prob=1.0, num_layer=1)
            le_bwd_sent_cell = self.get_rnncell("gru", self.sent_cell_size,
                                                keep_prob=1.0, num_layer=1)
            le_input_embedding, le_sent_size = get_bi_rnn_encode(
                le_input_embedding, le_fwd_sent_cell, le_bwd_sent_cell,
                self.context_lens, scope="sent_bi_rnn")
            le_output_embedding, _ = get_bi_rnn_encode(le_output_embedding,
                                                       le_fwd_sent_cell,
                                                       le_bwd_sent_cell,
                                                       self.output_lens,
                                                       scope="sent_bi_rnn",
                                                       reuse=True)
        else:
            raise ValueError(
                "Unknown sent_type. Must be one of [bow, rnn, bi_rnn]")
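        # (The label encoder mirrors the utterance encoder but reuses
        # self.embedding with its own PAD mask; its states feed ggammaNet
        # below, which appears to learn to predict the latent z from the
        # context and response alone for the self-labeling ("SL") path.)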
        # reshape input into dialogs
        if config.keep_prob < 1.0:
            le_input_embedding = tf.nn.dropout(le_input_embedding,
                                               config.keep_prob)

    # [le_enc_last_state, le_output_embedding]
    with variable_scope.variable_scope("lecontextRNN"):
        enc_cell = self.get_rnncell(config.cell_type,
                                    self.context_cell_size,
                                    keep_prob=1.0,
                                    num_layer=config.num_layer)
        # enc_last_state will be the same as the true last state
        le_input_embedding = tf.expand_dims(le_input_embedding, axis=2)
        _, le_enc_last_state = tf.nn.dynamic_rnn(
            enc_cell,
            le_input_embedding,
            dtype=tf.float32,
            sequence_length=self.context_lens)
        if config.num_layer > 1:
            if config.cell_type == 'lstm':
                le_enc_last_state = [temp.h for temp in le_enc_last_state]
            le_enc_last_state = tf.concat(le_enc_last_state, 1)
        else:
            if config.cell_type == 'lstm':
                le_enc_last_state = le_enc_last_state.h

    best_en = tf.concat([le_enc_last_state, le_output_embedding], 1)

    with variable_scope.variable_scope("ggammaNet"):
        enc_cell = self.get_rnncell(config.cell_type,
                                    200,
                                    keep_prob=1.0,
                                    num_layer=config.num_layer)
        input_embedding = tf.expand_dims(best_en, axis=2)
        _, zlabel = tf.nn.dynamic_rnn(enc_cell,
                                      input_embedding,
                                      dtype=tf.float32,
                                      sequence_length=self.context_lens)
        if config.num_layer > 1:
            if config.cell_type == 'lstm':
                # NOTE: the original iterated over enc_last_state here, which
                # looks like a copy-paste bug; zlabel is the state being unpacked
                zlabel = [temp.h for temp in zlabel]
            zlabel = tf.concat(zlabel, 1)
        else:
            if config.cell_type == 'lstm':
                zlabel = zlabel.h

    with variable_scope.variable_scope("generationNetwork"):
        gen_inputs = tf.concat([enc_last_state, latent_sample], 1)
        dec_inputs = gen_inputs
        selected_attribute_embedding = None

        # decoder initial state
        if config.num_layer > 1:
            dec_init_state = []
            for i in range(config.num_layer):
                temp_init = layers.fully_connected(dec_inputs,
                                                   self.dec_cell_size,
                                                   activation_fn=None,
                                                   scope="init_state-%d" % i)
                if config.cell_type == 'lstm':
                    temp_init = rnn_cell.LSTMStateTuple(temp_init, temp_init)
                dec_init_state.append(temp_init)
            dec_init_state = tuple(dec_init_state)
        else:
            dec_init_state = layers.fully_connected(dec_inputs,
                                                    self.dec_cell_size,
                                                    activation_fn=None,
                                                    scope="init_state")
            if config.cell_type == 'lstm':
                dec_init_state = rnn_cell.LSTMStateTuple(dec_init_state,
                                                         dec_init_state)

    with variable_scope.variable_scope("generationNetwork1"):
        gen_inputs_sl = tf.concat([le_enc_last_state, zlabel], 1)
        dec_inputs_sl = gen_inputs_sl
        selected_attribute_embedding = None

        # decoder initial state for the self-labeling path
        if config.num_layer > 1:
            dec_init_state_sl = []
            for i in range(config.num_layer):
                temp_init = layers.fully_connected(dec_inputs_sl,
                                                   self.dec_cell_size,
                                                   activation_fn=None,
                                                   scope="init_state-%d" % i)
                if config.cell_type == 'lstm':
                    temp_init = rnn_cell.LSTMStateTuple(temp_init, temp_init)
                dec_init_state_sl.append(temp_init)
            dec_init_state_sl = tuple(dec_init_state_sl)
        else:
            dec_init_state_sl = layers.fully_connected(dec_inputs_sl,
                                                       self.dec_cell_size,
                                                       activation_fn=None,
                                                       scope="init_state")
            if config.cell_type == 'lstm':
                dec_init_state_sl = rnn_cell.LSTMStateTuple(
                    dec_init_state_sl, dec_init_state_sl)

    with variable_scope.variable_scope("decoder"):
        dec_cell = self.get_rnncell(config.cell_type, self.dec_cell_size,
                                    config.keep_prob, config.num_layer)
        dec_cell = OutputProjectionWrapper(dec_cell, self.vocab_size)
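        # Two sets of decoder initial states are decoded with this shared
        # dec_cell: generationNetwork conditions on [c; z] (the CVAE path),
        # while generationNetwork1 conditions on [c_le; zlabel] (the SL path).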
        if forward:
            loop_func = decoder_fn_lib.context_decoder_fn_inference(
                None,
                dec_init_state,
                embedding,
                start_of_sequence_id=self.go_id,
                end_of_sequence_id=self.eos_id,
                maximum_length=self.max_utt_len,
                num_decoder_symbols=self.vocab_size,
                context_vector=selected_attribute_embedding)
            loop_func_sl = decoder_fn_lib.context_decoder_fn_inference(
                None,
                dec_init_state_sl,
                le_embedding,
                start_of_sequence_id=self.go_id,
                end_of_sequence_id=self.eos_id,
                maximum_length=self.max_utt_len,
                num_decoder_symbols=self.vocab_size,
                context_vector=selected_attribute_embedding)
            dec_input_embedding = None
            dec_input_embedding_sl = None
            dec_seq_lens = None
        else:
            loop_func = decoder_fn_lib.context_decoder_fn_train(
                dec_init_state, selected_attribute_embedding)
            loop_func_sl = decoder_fn_lib.context_decoder_fn_train(
                dec_init_state_sl, selected_attribute_embedding)
            dec_input_embedding = embedding_ops.embedding_lookup(
                embedding, self.output_tokens)
            dec_input_embedding_sl = embedding_ops.embedding_lookup(
                le_embedding, self.output_tokens)
            dec_input_embedding = dec_input_embedding[:, 0:-1, :]
            dec_input_embedding_sl = dec_input_embedding_sl[:, 0:-1, :]
            dec_seq_lens = self.output_lens - 1

            if config.keep_prob < 1.0:
                dec_input_embedding = tf.nn.dropout(dec_input_embedding,
                                                    config.keep_prob)
                dec_input_embedding_sl = tf.nn.dropout(
                    dec_input_embedding_sl, config.keep_prob)

            # apply word dropping: set dropped words to 0
            if config.dec_keep_prob < 1.0:
                keep_mask = tf.less_equal(
                    tf.random_uniform((batch_size, max_out_len - 1),
                                      minval=0.0,
                                      maxval=1.0), config.dec_keep_prob)
                keep_mask = tf.expand_dims(tf.to_float(keep_mask), 2)
                dec_input_embedding = dec_input_embedding * keep_mask
                dec_input_embedding_sl = dec_input_embedding_sl * keep_mask
                dec_input_embedding = tf.reshape(
                    dec_input_embedding,
                    [-1, max_out_len - 1, config.embed_size])
                dec_input_embedding_sl = tf.reshape(
                    dec_input_embedding_sl,
                    [-1, max_out_len - 1, config.embed_size])

        dec_outs, _, final_context_state = dynamic_rnn_decoder(
            dec_cell,
            loop_func,
            inputs=dec_input_embedding,
            sequence_length=dec_seq_lens)
        dec_outs_sl, _, final_context_state_sl = dynamic_rnn_decoder(
            dec_cell,
            loop_func_sl,
            inputs=dec_input_embedding_sl,
            sequence_length=dec_seq_lens)

        if final_context_state is not None:
            final_context_state = final_context_state[:, 0:array_ops.
                                                      shape(dec_outs)[1]]
            mask = tf.to_int32(tf.sign(tf.reduce_max(dec_outs, axis=2)))
            self.dec_out_words = tf.multiply(
                tf.reverse(final_context_state, axis=[1]), mask)
        else:
            self.dec_out_words = tf.argmax(dec_outs, 2)

        if final_context_state_sl is not None:
            final_context_state_sl = final_context_state_sl[:, 0:array_ops.
                                                            shape(dec_outs_sl)[1]]
            mask_sl = tf.to_int32(
                tf.sign(tf.reduce_max(dec_outs_sl, axis=2)))
            self.dec_out_words_sl = tf.multiply(
                tf.reverse(final_context_state_sl, axis=[1]), mask_sl)
        else:
            self.dec_out_words_sl = tf.argmax(dec_outs_sl, 2)

    if not forward:
        with variable_scope.variable_scope("loss"):
            labels = self.output_tokens[:, 1:]
            label_mask = tf.to_float(tf.sign(labels))

            rc_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=dec_outs, labels=labels)
            rc_loss = tf.reduce_sum(rc_loss * label_mask,
                                    reduction_indices=1)
            self.avg_rc_loss = tf.reduce_mean(rc_loss)

            sl_rc_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=dec_outs_sl, labels=labels)
            sl_rc_loss = tf.reduce_sum(sl_rc_loss * label_mask,
                                       reduction_indices=1)
            self.sl_rc_loss = tf.reduce_mean(sl_rc_loss)

            # used only for perplexity calculation, not for optimization
            self.rc_ppl = tf.exp(
                tf.reduce_sum(rc_loss) / tf.reduce_sum(label_mask))

            # reconstruction treated as an n-trial multimodal distribution
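            # label_mask zeroes PAD positions: e.g. for labels [5, 9, eos, 0, 0]
            # (0 = PAD), tf.sign gives [1, 1, 1, 0, 0], so rc_loss sums
            # cross-entropy over real tokens only and rc_ppl is
            # exp(total loss / total real tokens), the usual per-word perplexity.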
""" kld = gaussian_kld(recog_mu, recog_logvar, prior_mu, prior_logvar) self.avg_kld = tf.reduce_mean(kld) if log_dir is not None: kl_weights = tf.minimum( tf.to_float(self.global_t) / config.full_kl_step, 1.0) else: kl_weights = tf.constant(1.0) self.label_loss = tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits( labels=latent_sample, logits=zlabel)) self.kl_w = kl_weights self.elbo = self.avg_rc_loss + kl_weights * self.avg_kld self.cvae_loss = self.elbo + +0.1 * self.label_loss self.sl_loss = self.sl_rc_loss tf.summary.scalar("rc_loss", self.avg_rc_loss) tf.summary.scalar("elbo", self.elbo) tf.summary.scalar("kld", self.avg_kld) self.summary_op = tf.summary.merge_all() self.log_p_z = norm_log_liklihood(latent_sample, prior_mu, prior_logvar) self.log_q_z_xy = norm_log_liklihood(latent_sample, recog_mu, recog_logvar) self.est_marginal = tf.reduce_mean(rc_loss - self.log_p_z + self.log_q_z_xy) self.train_sl_ops = self.optimize(sess, config, self.sl_loss, log_dir, scope="SL") self.train_ops = self.optimize(sess, config, self.cvae_loss, log_dir, scope="CVAE") self.saver = tf.train.Saver(tf.global_variables(), write_version=tf.train.SaverDef.V2)