def build(self):
    def get_topic_loss_reg(topic_embeddings):
        topic_embeddings_norm = topic_embeddings / tf.norm(topic_embeddings, axis=1, keepdims=True)
        self.topic_dots = tf.clip_by_value(tf.matmul(topic_embeddings_norm, tf.transpose(topic_embeddings_norm)), -1., 1.)
        topic_loss_reg = tf.reduce_mean(tf.square(self.topic_dots - tf.eye(self.config.n_topic)))
        return topic_loss_reg

    # -------------- Build Model --------------
    tf.reset_default_graph()
    tf.set_random_seed(self.config.seed)

    self.t_variables['bow'] = tf.placeholder(tf.float32, [None, self.config.dim_bow])
    self.t_variables['keep_prob'] = tf.placeholder(tf.float32)

    # encode bow
    with tf.variable_scope('topic/enc', reuse=False):
        hidden_bow_ = tf.layers.Dense(units=self.config.dim_hidden_bow, activation=tf.nn.tanh, name='hidden_bow')(self.t_variables['bow'])
        hidden_bow = tf.layers.Dropout(self.t_variables['keep_prob'])(hidden_bow_)
        means_bow = tf.layers.Dense(units=self.config.dim_latent_bow, name='mean_bow')(hidden_bow)
        logvars_bow = tf.layers.Dense(units=self.config.dim_latent_bow, kernel_initializer=tf.constant_initializer(0),
                                      bias_initializer=tf.constant_initializer(0), name='logvar_bow')(hidden_bow)

        latents_bow = sample_latents(means_bow, logvars_bow)  # sample latent vectors
        self.prob_topic = tf.layers.Dense(units=self.config.n_topic, activation=tf.nn.softmax, name='prob_topic')(latents_bow)  # inference of topic probabilities

    # decode bow
    with tf.variable_scope('shared', reuse=False):
        self.bow_embeddings = tf.get_variable('emb', [self.config.dim_bow, self.config.dim_emb], dtype=tf.float32,
                                              initializer=tf.contrib.layers.xavier_initializer(), trainable=self.config.train_emb)  # embeddings of vocab

    with tf.variable_scope('topic/dec', reuse=False):
        self.topic_embeddings = tf.get_variable('topic_emb', [self.config.n_topic, self.config.dim_emb], dtype=tf.float32,
                                                initializer=tf.contrib.layers.xavier_initializer())  # embeddings of topics
        self.topic_bow = tf.nn.softmax(tf.matmul(self.topic_embeddings, self.bow_embeddings, transpose_b=True), 1)  # bow vectors for each topic
        self.logits_bow = tf_log(tf.matmul(self.prob_topic, self.topic_bow))  # predicted bow distribution, n_batch x dim_bow

    # define losses
    self.topic_losses_recon = -tf.reduce_sum(tf.multiply(self.t_variables['bow'], self.logits_bow), 1)
    self.topic_loss_recon = tf.reduce_mean(self.topic_losses_recon)  # negative log likelihood of each word

    self.topic_losses_kl = compute_kl_loss(means_bow, logvars_bow)  # KL divergence between the latent distribution and a standard Gaussian
    self.topic_loss_kl = tf.reduce_mean(self.topic_losses_kl, 0)  # mean of KL losses over the batch

    self.topic_loss_reg = get_topic_loss_reg(self.topic_embeddings)

    self.loss = self.topic_loss_recon + self.topic_loss_kl + self.config.reg * self.topic_loss_reg

    # define optimizer
    if self.config.opt == 'Adam':
        optimizer = tf.train.AdamOptimizer(self.config.lr)
    elif self.config.opt == 'Adagrad':
        optimizer = tf.train.AdagradOptimizer(self.config.lr)

    self.grad_vars = optimizer.compute_gradients(self.loss)
    self.clipped_grad_vars = [(tf.clip_by_value(grad, -self.config.grad_clip, self.config.grad_clip), var) for grad, var in self.grad_vars]
    self.global_step = tf.Variable(0, name='global_step', trainable=False)
    self.opt = optimizer.apply_gradients(self.clipped_grad_vars, global_step=self.global_step)

    # monitor
    self.n_bow = tf.reduce_sum(self.t_variables['bow'], 1)
    self.topic_ppls = tf.divide(self.topic_losses_recon + self.topic_losses_kl, tf.maximum(1e-5, self.n_bow))

    # growth criteria
    self.n_topics = tf.multiply(tf.expand_dims(self.n_bow, -1), self.prob_topic)
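# NOTE: `sample_latents` and `compute_kl_loss` are helpers defined elsewhere in the repository.
# The sketch below is only an illustration of what they presumably do, assuming the standard
# reparameterization trick and a diagonal-Gaussian KL against a standard normal prior;
# the actual signatures and details may differ.
import tensorflow as tf

def sample_latents(means, logvars):
    # Reparameterization trick: z = mu + sigma * eps with eps ~ N(0, I),
    # so gradients can flow through the sampling step.
    noises = tf.random_normal(tf.shape(means))
    return means + tf.exp(0.5 * logvars) * noises

def compute_kl_loss(means, logvars):
    # KL( N(mu, diag(exp(logvar))) || N(0, I) ), summed over the latent dimension.
    return -0.5 * tf.reduce_sum(1. + logvars - tf.square(means) - tf.exp(logvars), axis=-1)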
def build(self):
    def get_prob_topic(tree_prob_leaf, prob_depth):
        tree_prob_topic = defaultdict(float)
        leaf_ancestor_idxs = {leaf_idx: get_ancestor_idxs(leaf_idx, self.child_to_parent_idxs) for leaf_idx in tree_prob_leaf}
        for leaf_idx, ancestor_idxs in leaf_ancestor_idxs.items():
            prob_leaf = tree_prob_leaf[leaf_idx]
            for i, ancestor_idx in enumerate(ancestor_idxs):
                prob_ancestor = prob_leaf * tf.expand_dims(prob_depth[:, i], -1)
                tree_prob_topic[ancestor_idx] += prob_ancestor
        prob_topic = tf.concat([tree_prob_topic[topic_idx] for topic_idx in self.topic_idxs], -1)
        return prob_topic

    def get_tree_topic_bow(tree_topic_embeddings):
        tree_topic_bow = {}
        for topic_idx, depth in self.tree_depth.items():
            topic_embedding = tree_topic_embeddings[topic_idx]
            temperature = tf.constant(self.config.depth_temperature ** (1. / depth), dtype=tf.float32)
            logits = tf.matmul(topic_embedding, self.bow_embeddings, transpose_b=True)
            tree_topic_bow[topic_idx] = softmax_with_temperature(logits, axis=-1, temperature=temperature)
        return tree_topic_bow

    def get_topic_loss_reg(tree_topic_embeddings):
        def get_tree_mask_reg(all_child_idxs):
            tree_mask_reg = np.zeros([len(all_child_idxs), len(all_child_idxs)], dtype=np.float32)
            for parent_idx, child_idxs in self.tree_idxs.items():
                neighbor_idxs = child_idxs
                for neighbor_idx1 in neighbor_idxs:
                    for neighbor_idx2 in neighbor_idxs:
                        neighbor_index1 = all_child_idxs.index(neighbor_idx1)
                        neighbor_index2 = all_child_idxs.index(neighbor_idx2)
                        tree_mask_reg[neighbor_index1, neighbor_index2] = tree_mask_reg[neighbor_index2, neighbor_index1] = 1.
            return tree_mask_reg

        all_child_idxs = list(self.child_to_parent_idxs.keys())
        self.diff_topic_embeddings = tf.concat([tree_topic_embeddings[child_idx] - tree_topic_embeddings[self.child_to_parent_idxs[child_idx]]
                                                for child_idx in all_child_idxs], axis=0)
        diff_topic_embeddings_norm = self.diff_topic_embeddings / tf.norm(self.diff_topic_embeddings, axis=1, keepdims=True)
        self.topic_dots = tf.clip_by_value(tf.matmul(diff_topic_embeddings_norm, tf.transpose(diff_topic_embeddings_norm)), -1., 1.)

        self.tree_mask_reg = get_tree_mask_reg(all_child_idxs)
        self.topic_losses_reg = tf.square(self.topic_dots - tf.eye(len(all_child_idxs))) * self.tree_mask_reg
        self.topic_loss_reg = tf.reduce_sum(self.topic_losses_reg) / tf.reduce_sum(self.tree_mask_reg)
        return self.topic_loss_reg

    # -------------- Build Model --------------
    tf.reset_default_graph()
    tf.set_random_seed(self.config.seed)

    self.t_variables['bow'] = tf.placeholder(tf.float32, [None, self.config.dim_bow])
    self.t_variables['keep_prob'] = tf.placeholder(tf.float32)

    # encode bow
    with tf.variable_scope('topic/enc', reuse=False):
        hidden_bow_ = tf.layers.Dense(units=self.config.dim_hidden_bow, activation=tf.nn.tanh, name='hidden_bow')(self.t_variables['bow'])
        hidden_bow = tf.layers.Dropout(self.t_variables['keep_prob'])(hidden_bow_)
        means_bow = tf.layers.Dense(units=self.config.dim_latent_bow, name='mean_bow')(hidden_bow)
        logvars_bow = tf.layers.Dense(units=self.config.dim_latent_bow, kernel_initializer=tf.constant_initializer(0),
                                      bias_initializer=tf.constant_initializer(0), name='logvar_bow')(hidden_bow)

        latents_bow = sample_latents(means_bow, logvars_bow)  # sample latent vectors

        prob_layer = lambda h: tf.nn.sigmoid(tf.matmul(latents_bow, h, transpose_b=True))
        tree_sticks_topic, tree_states_sticks_topic = doubly_rnn(self.config.dim_latent_bow, self.tree_idxs, output_layer=prob_layer, name='sticks_topic')
        self.tree_prob_leaf = tsbp(tree_sticks_topic, self.tree_idxs)

        sticks_depth, _ = rnn(self.config.dim_latent_bow, self.n_depth, output_layer=prob_layer, name='prob_depth')
        self.prob_depth = sbp(sticks_depth, self.n_depth)

        self.prob_topic = get_prob_topic(self.tree_prob_leaf, self.prob_depth)  # n_batch x n_topic

    # decode bow
    with tf.variable_scope('shared', reuse=False):
        self.bow_embeddings = tf.get_variable('emb', [self.config.dim_bow, self.config.dim_emb], dtype=tf.float32,
                                              initializer=tf.contrib.layers.xavier_initializer())  # embeddings of vocab

    with tf.variable_scope('topic/dec', reuse=False):
        emb_layer = lambda h: tf.layers.Dense(units=self.config.dim_emb, name='output')(tf.nn.tanh(h))
        self.tree_topic_embeddings, tree_states_topic_embeddings = doubly_rnn(self.config.dim_emb, self.tree_idxs, output_layer=emb_layer, name='emb_topic')
        self.tree_topic_bow = get_tree_topic_bow(self.tree_topic_embeddings)  # bow vectors for each topic
        self.topic_bow = tf.concat([self.tree_topic_bow[topic_idx] for topic_idx in self.topic_idxs], 0)  # n_topic x dim_bow (K x V)
        self.logits_bow = tf_log(tf.matmul(self.prob_topic, self.topic_bow))  # predicted bow distribution, n_batch x dim_bow

    # define losses
    self.topic_losses_recon = -tf.reduce_sum(tf.multiply(self.t_variables['bow'], self.logits_bow), 1)
    self.topic_loss_recon = tf.reduce_mean(self.topic_losses_recon)  # negative log likelihood of each word

    self.topic_losses_kl = compute_kl_losses(means_bow, logvars_bow)  # KL divergence between the latent distribution and a standard Gaussian
    self.topic_loss_kl = tf.reduce_mean(self.topic_losses_kl, 0)  # mean of KL losses over the batch

    self.topic_embeddings = tf.concat([self.tree_topic_embeddings[topic_idx] for topic_idx in self.topic_idxs], 0)  # temporary
    self.topic_loss_reg = get_topic_loss_reg(self.tree_topic_embeddings)

    self.global_step = tf.Variable(0, name='global_step', trainable=False)
    self.loss = self.topic_loss_recon + self.topic_loss_kl + self.config.reg * self.topic_loss_reg

    # define optimizer
    if self.config.opt == 'Adam':
        optimizer = tf.train.AdamOptimizer(self.config.lr)
    elif self.config.opt == 'Adagrad':
        optimizer = tf.train.AdagradOptimizer(self.config.lr)

    self.grad_vars = optimizer.compute_gradients(self.loss)
    self.clipped_grad_vars = [(tf.clip_by_value(grad, -self.config.grad_clip, self.config.grad_clip), var) for grad, var in self.grad_vars]
    self.opt = optimizer.apply_gradients(self.clipped_grad_vars, global_step=self.global_step)

    # monitor
    self.n_bow = tf.reduce_sum(self.t_variables['bow'], 1)
    self.topic_ppls = tf.divide(self.topic_losses_recon + self.topic_losses_kl, tf.maximum(1e-5, self.n_bow))

    # growth criteria
    self.n_topics = tf.multiply(tf.expand_dims(self.n_bow, -1), self.prob_topic)
    self.arcs_bow = tf.acos(tf.matmul(tf.linalg.l2_normalize(self.bow_embeddings, axis=-1),
                                      tf.linalg.l2_normalize(self.topic_embeddings, axis=-1), transpose_b=True))  # dim_bow x n_topic
    self.rads_bow = tf.multiply(tf.matmul(self.t_variables['bow'], self.arcs_bow), self.prob_topic)  # n_batch x n_topic
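# NOTE: `doubly_rnn`, `rnn`, `tsbp`, `sbp`, and `softmax_with_temperature` are defined elsewhere
# in the repository. The sketch below only illustrates the two simplest ones, assuming `sticks`
# is a batch_l x n_depth tensor of break probabilities in [0, 1]; the actual implementations
# (especially the tree-structured `tsbp`) may differ.
import tensorflow as tf

def sbp(sticks, n_depth):
    # Truncated stick-breaking over depths:
    # prob_k = stick_k * prod_{j<k} (1 - stick_j), and the last depth takes the remaining mass.
    probs = []
    rest = tf.ones_like(sticks[:, 0:1])
    for depth in range(n_depth):
        if depth < n_depth - 1:
            prob = sticks[:, depth:depth + 1] * rest
            rest = rest * (1. - sticks[:, depth:depth + 1])
        else:
            prob = rest
        probs.append(prob)
    return tf.concat(probs, -1)  # batch_l x n_depth, rows sum to 1

def softmax_with_temperature(logits, axis=-1, temperature=1.):
    # Temperature-scaled softmax; the depth-dependent temperature in get_tree_topic_bow above
    # controls how peaked each topic's word distribution is.
    return tf.nn.softmax(logits / temperature, axis=axis)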
def build(self):
    self.global_step = tf.Variable(0, name='global_step', trainable=False)
    self.softmax_temperature = tf.maximum(
        self.config.max_temperature - tf.cast(tf.divide(self.global_step, tf.constant(self.config.linear_steps)), dtype=tf.float32),
        self.config.min_temperature)

    with tf.name_scope('t_variables'):
        self.sample = self.t_variables['sample']

        self.batch_l = self.t_variables['batch_l']
        self.doc_l = self.t_variables['doc_l']
        self.sent_l = self.t_variables['sent_l']
        self.dec_sent_l = self.t_variables['dec_sent_l']  # batch_l x max_doc_l

        self.max_doc_l = tf.reduce_max(self.doc_l)
        self.max_sent_l = tf.reduce_max(self.sent_l)
        self.max_dec_sent_l = tf.reduce_max(self.dec_sent_l)  # = max_sent_l + 1

        self.mask_doc = tf.sequence_mask(self.doc_l, dtype=tf.float32)
        self.mask_sent = tf.sequence_mask(self.sent_l, dtype=tf.float32)

        mask_bow = np.zeros(self.config.n_vocab)
        mask_bow[self.config.bow_idxs] = 1.
        self.mask_bow = tf.constant(mask_bow, dtype=tf.float32)

        self.enc_keep_prob = self.t_variables['enc_keep_prob']

    # ------------------------------Encoder------------------------------
    with tf.variable_scope('emb'):
        with tf.variable_scope('word', reuse=False):
            pad_embedding = tf.zeros([1, self.config.dim_emb], dtype=tf.float32)
            nonpad_embeddings = tf.get_variable('emb', [self.config.n_vocab - 1, self.config.dim_emb], dtype=tf.float32,
                                                initializer=tf.contrib.layers.xavier_initializer())
            self.embeddings = tf.concat([pad_embedding, nonpad_embeddings], 0)  # n_vocab x dim_emb
            self.bow_embeddings = tf.nn.embedding_lookup(self.embeddings, self.config.bow_idxs)  # dim_bow x dim_emb

            # get sentence embeddings
            self.enc_input_idxs = tf.one_hot(self.t_variables['enc_input_idxs'], depth=self.config.n_vocab)  # batch_l x max_doc_l x max_sent_l x n_vocab
            self.enc_inputs = tf.tensordot(self.enc_input_idxs, self.embeddings, axes=[[-1], [0]])  # batch_l x max_doc_l x max_sent_l x dim_emb

        with tf.variable_scope('sent', reuse=False):
            self.sent_outputs, self.sent_state = \
                encode_inputs(self, enc_inputs=self.enc_inputs, sent_l=self.sent_l)  # batch_l x max_doc_l x dim_hidden*2

        # get sentence latents
        with tf.variable_scope('latents_sent', reuse=False):
            self.latents_sent_posterior, self.means_sent_posterior, self.logvars_sent_posterior = \
                encode_latents_gauss(self.sent_state, dim_latent=self.config.dim_latent, sample=self.sample,
                                     config=self.config, name='sent_posterior', min_logvar=self.config.min_logvar)  # batch_l x max_doc_l x dim_latent

    # ------------------------------Discriminator------------------------------
    with tf.variable_scope('disc'):
        with tf.variable_scope('prob_topic', reuse=False):
            # encode by TSNTM
            self.probs_sent_topic_posterior, self.tree_sent_sticks_path, self.tree_sent_sticks_depth = \
                encode_nhdp_probs_topic_posterior(self, self.sent_state.get_shape()[-1], self.sent_state,
                                                  self.mask_doc, self.config)  # batch_l x max_doc_l x n_topic

        with tf.name_scope('latents_topic'):
            # get topic sentence posterior distribution for each document
            self.probs_topic_posterior = tf.reduce_sum(self.probs_sent_topic_posterior, 1)  # batch_l x n_topic

            self.means_sent_topic_posterior = tf.multiply(tf.expand_dims(self.probs_sent_topic_posterior, -1),
                                                          tf.expand_dims(self.means_sent_posterior, -2))  # batch_l x max_doc_l x n_topic x dim_latent
            self.means_topic_posterior_ = tf.reduce_sum(self.means_sent_topic_posterior, 1) / \
                tf.expand_dims(self.probs_topic_posterior, -1)  # batch_l x n_topic x dim_latent
            self.means_topic_posterior = tf_clip_means(self.means_topic_posterior_, self.probs_topic_posterior)

            diffs_sent_topic_posterior = tf.expand_dims(self.means_sent_posterior, 2) - \
                tf.expand_dims(self.means_topic_posterior, 1)  # batch_l x max_doc_l x n_topic x dim_latent
            self.covs_sent_topic_posterior = tf.multiply(
                tf.expand_dims(tf.expand_dims(self.probs_sent_topic_posterior, -1), -1),
                tf.matrix_diag(tf.expand_dims(tf.exp(self.logvars_sent_posterior), 2)) +
                tf.matmul(tf.expand_dims(diffs_sent_topic_posterior, -1),
                          tf.expand_dims(diffs_sent_topic_posterior, -2)))  # batch_l x max_doc_l x n_topic x dim_latent x dim_latent
            self.covs_topic_posterior_ = tf.reduce_sum(self.covs_sent_topic_posterior, 1) / \
                tf.expand_dims(tf.expand_dims(self.probs_topic_posterior, -1), -1)  # batch_l x n_topic x dim_latent x dim_latent
            self.covs_topic_posterior = tf_clip_covs(self.covs_topic_posterior_, self.probs_topic_posterior)

            self.latents_topic_posterior = sample_latents_fullcov(self.means_topic_posterior, self.covs_topic_posterior,
                                                                  seed=self.config.seed, sample=self.sample)

            if not self.config.prior_root:
                self.mean_root_prior = tf.zeros([self.batch_l, self.config.dim_latent], dtype=tf.float32)
                self.cov_root_prior = tf.eye(self.config.dim_latent, batch_shape=[self.batch_l], dtype=tf.float32) * self.config.cov_root

            self.means_topic_prior = get_params_topic_prior(self, self.means_topic_posterior, self.mean_root_prior)  # batch_l x n_topic x dim_latent
            self.covs_topic_prior = get_params_topic_prior(self, self.covs_topic_posterior, self.cov_root_prior)  # batch_l x n_topic x dim_latent x dim_latent

    # ------------------------------Decoder------------------------------
    with tf.variable_scope('dec'):
        # decode for training sent
        with tf.variable_scope('sent', initializer=tf.contrib.layers.xavier_initializer(), dtype=tf.float32, reuse=False):
            self.dec_cell = tf.contrib.rnn.GRUCell(self.config.dim_hidden)
            self.dec_cell = tf.contrib.rnn.DropoutWrapper(self.dec_cell, output_keep_prob=self.t_variables['dec_keep_prob'])
            self.dec_sent_cell = self.dec_cell

            self.latent_hidden_layer = tf.layers.Dense(units=self.config.dim_hidden, activation=tf.nn.relu, name='latent_hidden_linear')
            self.dec_sent_initial_state = self.latent_hidden_layer(self.latents_sent_posterior)  # batch_l x max_doc_l x dim_hidden
            self.output_layer = tf.layers.Dense(self.config.n_vocab, use_bias=False, name='out')

            if self.config.attention:
                self.sent_outputs_flat = tf.reshape(self.sent_outputs, [self.batch_l * self.max_doc_l, self.max_sent_l, self.config.dim_hidden * 2])
                self.att_sent_l_flat = tf.reshape(tf.maximum(self.sent_l, 1), [self.batch_l * self.max_doc_l])
                self.att_sent_mechanism = tf.contrib.seq2seq.LuongAttention(num_units=self.config.dim_hidden,
                                                                            memory=self.sent_outputs_flat,
                                                                            memory_sequence_length=self.att_sent_l_flat)
                self.att_cell = tf.contrib.seq2seq.AttentionWrapper(self.dec_cell,
                                                                    attention_mechanism=self.att_sent_mechanism,
                                                                    attention_layer_size=self.config.dim_hidden)
                self.dec_sent_cell = self.att_cell

            # teacher forcing
            self.dec_input_idxs = self.t_variables['dec_input_idxs']  # batch_l x max_doc_l x max_dec_sent_l
            self.dec_inputs = tf.nn.embedding_lookup(self.embeddings, self.dec_input_idxs)  # batch_l x max_doc_l x max_dec_sent_l x dim_emb

            # output_sent_l == dec_sent_l
            self.output_logits_flat, self.output_sent_l_flat = decode_output_logits_flat(
                self, dec_cell=self.dec_sent_cell, dec_initial_state=self.dec_sent_initial_state,
                dec_inputs=self.dec_inputs, dec_sent_l=self.dec_sent_l,
                latents_input=self.latents_sent_posterior)  # batch_l*max_doc_l x max_output_sent_l x n_vocab

            self.output_sent_l = tf.reshape(self.output_sent_l_flat, [self.batch_l, self.max_doc_l])
            self.max_output_sent_l = tf.reduce_max(self.output_sent_l)
            self.output_logits = tf.reshape(self.output_logits_flat,
                                            [self.batch_l, self.max_doc_l, self.max_output_sent_l, self.config.n_vocab],
                                            name='output_logits')

            if self.config.disc_gumbel:
                self.output_input_idxs = sample_gumbels(self.output_logits, self.softmax_temperature,
                                                        self.config.seed, self.sample)  # batch_l x max_doc_l x max_output_sent_l x n_vocab
            else:
                self.output_input_idxs = self.output_logits

        # decode for training topic probs
        with tf.variable_scope('sent', initializer=tf.contrib.layers.xavier_initializer(), dtype=tf.float32, reuse=True):
            self.dec_topic_cell = self.dec_cell

            if self.config.attention:
                self.topic_outputs_flat = tf.contrib.seq2seq.tile_batch(
                    tf.reshape(self.sent_outputs, [self.batch_l, self.max_doc_l * self.max_sent_l, self.sent_outputs.get_shape()[-1]]),
                    multiplier=self.config.n_topic)  # batch_l*n_topic x max_doc_l*max_sent_l x dim_hidden*2
                self.score_mask = tf.contrib.seq2seq.tile_batch(
                    tf.reshape(tf.sequence_mask(self.sent_l), [self.batch_l, self.max_doc_l * self.max_sent_l]),
                    multiplier=self.config.n_topic)  # batch_l*n_topic x max_doc_l*max_sent_l
                self.hier_score = tf.reshape(tf.transpose(self.probs_sent_topic_posterior, [0, 2, 1]),
                                             [self.batch_l * self.config.n_topic, self.max_doc_l])  # batch_l*n_topic x max_doc_l

                self.att_topic_mechanism = HierarchicalAttention(num_units=self.config.dim_hidden, memory=self.topic_outputs_flat,
                                                                 score_mask=self.score_mask, hier_score=self.hier_score)
                self.att_topic_cell = AttentionWrapper(self.dec_cell, attention_mechanism=self.att_topic_mechanism,
                                                       attention_layer_size=self.config.dim_hidden)
                self.dec_topic_cell = self.att_topic_cell

            if not self.config.disc_mean:
                self.dec_topic_initial_state = self.latent_hidden_layer(self.latents_topic_posterior)
                dec_topic_outputs, self.summary_sent_l_flat = decode_output_sample_flat(
                    self, dec_cell=self.dec_topic_cell, dec_initial_state=self.dec_topic_initial_state,
                    softmax_temperature=self.softmax_temperature, sample=self.sample,
                    latents_input=self.latents_topic_posterior)  # batch_l*max_doc_l x max_summary_sent_l x n_vocab
            else:
                self.dec_topic_initial_state = self.latent_hidden_layer(self.means_topic_posterior)
                dec_topic_outputs, self.summary_sent_l_flat = decode_output_sample_flat(
                    self, dec_cell=self.dec_topic_cell, dec_initial_state=self.dec_topic_initial_state,
                    softmax_temperature=self.softmax_temperature, sample=self.sample,
                    latents_input=self.means_topic_posterior)  # batch_l*max_doc_l x max_summary_sent_l x n_vocab

            self.summary_sent_l = tf.reshape(self.summary_sent_l_flat, [self.batch_l, self.config.n_topic])
            self.max_summary_sent_l = tf.reduce_max(self.summary_sent_l)

            if self.config.disc_gumbel:
                summary_input_idxs_flat = dec_topic_outputs.sample_id
            else:
                summary_input_idxs_flat = dec_topic_outputs.rnn_output
            self.summary_input_idxs = tf.reshape(summary_input_idxs_flat,
                                                 [self.batch_l, self.config.n_topic, self.max_summary_sent_l, self.config.n_vocab],
                                                 name='summary_input_idxs')

            # re-encode topic sentence outputs
            self.summary_inputs = tf.tensordot(self.summary_input_idxs, self.embeddings, axes=[[-1], [0]])  # batch_l x n_topic x max_summary_sent_l x dim_emb
            self.summary_input_sent_l = self.summary_sent_l - 1  # to remove EOS
            self.mask_summary_sent = tf.sequence_mask(self.summary_input_sent_l,
                                                      maxlen=self.max_summary_sent_l, dtype=tf.float32)  # batch_l x n_topic x max_summary_sent_l
            self.mask_summary_doc = tf.ones([self.batch_l, self.config.n_topic], dtype=tf.float32)

        # beam decode for inference of original sentences
        with tf.variable_scope('sent', initializer=tf.contrib.layers.xavier_initializer(), dtype=tf.float32, reuse=True):
            self.beam_dec_sent_cell = self.dec_cell
            if self.config.attention:
                self.beam_sent_outputs_flat = tf.contrib.seq2seq.tile_batch(self.sent_outputs_flat, multiplier=self.config.beam_width)
                self.beam_att_sent_l_flat = tf.contrib.seq2seq.tile_batch(self.att_sent_l_flat, multiplier=self.config.beam_width)
                self.beam_att_sent_mechanism = tf.contrib.seq2seq.LuongAttention(num_units=self.config.dim_hidden,
                                                                                 memory=self.beam_sent_outputs_flat,
                                                                                 memory_sequence_length=self.beam_att_sent_l_flat)
                self.beam_dec_sent_cell = tf.contrib.seq2seq.AttentionWrapper(self.beam_dec_sent_cell,
                                                                              attention_mechanism=self.beam_att_sent_mechanism,
                                                                              attention_layer_size=self.config.dim_hidden)

            # infer original sentences
            self.beam_output_idxs, _, _ = decode_beam_output_token_idxs(self, beam_dec_cell=self.beam_dec_sent_cell,
                                                                        dec_initial_state=self.dec_sent_initial_state,
                                                                        latents_input=self.means_sent_posterior,
                                                                        name='beam_output_idxs')

        # beam decode for inference of topic sentences
        with tf.variable_scope('sent', initializer=tf.contrib.layers.xavier_initializer(), dtype=tf.float32, reuse=True):
            self.beam_dec_topic_cell = self.dec_cell
            if self.config.attention:
                self.beam_topic_outputs_flat = tf.contrib.seq2seq.tile_batch(self.topic_outputs_flat, multiplier=self.config.beam_width)
                self.beam_score_mask = tf.contrib.seq2seq.tile_batch(self.score_mask, multiplier=self.config.beam_width)
                self.beam_hier_score = tf.contrib.seq2seq.tile_batch(self.hier_score, multiplier=self.config.beam_width)
                self.beam_att_topic_mechanism = HierarchicalAttention(num_units=self.config.dim_hidden,
                                                                      memory=self.beam_topic_outputs_flat,
                                                                      score_mask=self.beam_score_mask,
                                                                      hier_score=self.beam_hier_score)
                self.beam_dec_topic_cell = AttentionWrapper(self.beam_dec_topic_cell,
                                                            attention_mechanism=self.beam_att_topic_mechanism,
                                                            attention_layer_size=self.config.dim_hidden)

            # infer topic sentences
            self.beam_summary_idxs, _, _ = decode_beam_output_token_idxs(self, beam_dec_cell=self.beam_dec_topic_cell,
                                                                         dec_initial_state=self.dec_topic_initial_state,
                                                                         latents_input=self.latents_topic_posterior,
                                                                         name='beam_summary_idxs')

            self.beam_mask_summary_sent = tf.logical_not(tf.equal(self.beam_summary_idxs, self.config.EOS_IDX))  # batch_l x n_topic x max_summary_sent_l
            self.beam_summary_input_sent_l = tf.reduce_sum(tf.cast(self.beam_mask_summary_sent, tf.int32), -1)  # batch_l x n_topic

            beam_summary_soft_idxs = tf.one_hot(tf.where(self.beam_mask_summary_sent, self.beam_summary_idxs,
                                                         tf.zeros_like(self.beam_summary_idxs)), depth=self.config.n_vocab)
            self.beam_summary_inputs = tf.tensordot(beam_summary_soft_idxs, self.embeddings, [[-1], [0]])  # batch_l x n_topic x max_beam_summary_sent_l x dim_emb

    # ------------------------------Discriminator------------------------------
    with tf.variable_scope('emb'):
        with tf.variable_scope('sent', reuse=True):
            _, self.summary_state = encode_inputs(self, enc_inputs=self.summary_inputs,
                                                  sent_l=self.summary_input_sent_l)  # batch_l x max_doc_l x dim_hidden*2
            _, self.beam_summary_state = encode_inputs(self, enc_inputs=self.beam_summary_inputs,
                                                       sent_l=self.beam_summary_input_sent_l)  # batch_l x max_doc_l x dim_hidden*2

    with tf.variable_scope('disc'):
        with tf.variable_scope('prob_topic', reuse=True):
            self.probs_summary_topic_posterior, _, _ = \
                encode_nhdp_probs_topic_posterior(self, self.summary_state.get_shape()[-1], self.summary_state,
                                                  self.mask_summary_doc, self.config)
            self.logits_summary_topic_posterior_ = tf_log(tf.matrix_diag_part(self.probs_summary_topic_posterior))  # batch_l x n_topic
            self.logits_summary_topic_posterior = tf_clip_vals(self.logits_summary_topic_posterior_, self.probs_topic_posterior)

    # ------------------------------Optimizer and Loss------------------------------
    with tf.name_scope('opt'):
        partition_doc = tf.cast(self.mask_doc, dtype=tf.int32)
        self.n_sents = tf.cast(tf.reduce_sum(self.doc_l), dtype=tf.float32)
        self.n_tokens = tf.reduce_sum(self.dec_sent_l)

        # ------------------------------Reconstruction Loss of Language Model------------------------------
        # target and mask
        self.dec_target_idxs = self.t_variables['dec_target_idxs']  # batch_l x max_doc_l x max_dec_sent_l
        self.dec_sent_l = self.t_variables['dec_sent_l']  # batch_l x max_doc_l
        self.max_dec_sent_l = tf.reduce_max(self.dec_sent_l)  # = max_sent_l + 1
        self.dec_mask_sent = tf.sequence_mask(self.dec_sent_l, maxlen=self.max_dec_sent_l, dtype=tf.float32)
        self.dec_target_idxs_flat = tf.reshape(self.dec_target_idxs, [self.batch_l * self.max_doc_l, self.max_dec_sent_l])
        self.dec_mask_sent_flat = tf.reshape(self.dec_mask_sent, [self.batch_l * self.max_doc_l, self.max_dec_sent_l])

        # nll for each token (summed over sentence)
        self.recon_max_sent_l = tf.minimum(self.max_dec_sent_l, self.max_output_sent_l) if self.config.sample else None
        losses_recon_flat = tf.reduce_sum(
            tf.contrib.seq2seq.sequence_loss(self.output_logits_flat[:, :self.recon_max_sent_l, :],
                                             self.dec_target_idxs_flat[:, :self.recon_max_sent_l],
                                             self.dec_mask_sent_flat[:, :self.recon_max_sent_l],
                                             average_across_timesteps=False,
                                             average_across_batch=False), -1)  # batch_l*max_doc_l
        self.losses_recon = tf.reshape(losses_recon_flat, [self.batch_l, self.max_doc_l])
        self.loss_recon = tf.reduce_mean(tf.dynamic_partition(self.losses_recon, partition_doc, num_partitions=2)[1])  # average over doc x batch

        # ------------------------------KL divergence Loss of Topic Probability Distribution------------------------------
        if self.config.topic_model:
            self.probs_sent_topic_prior = tf.expand_dims(self.probs_doc_topic_posterior, 1)  # batch_l x 1 x n_topic
        else:
            self.probs_sent_topic_prior = tf.ones_like(self.probs_sent_topic_posterior, dtype=tf.float32) / \
                self.config.n_topic  # batch_l x max_doc_l x n_topic, uniform distribution over topics
        self.losses_kl_prob = tf.reduce_sum(tf.multiply(self.probs_sent_topic_posterior,
                                                        tf_log(self.probs_sent_topic_posterior) - tf_log(self.probs_sent_topic_prior)), -1)
        self.loss_kl_prob = tf.reduce_mean(tf.dynamic_partition(self.losses_kl_prob, partition_doc, num_partitions=2)[1])  # average over doc x batch

        # ------------------------------KL divergence Loss of Sentence Latents Distribution------------------------------
        self.losses_kl_sent_gauss = compute_kl_losses_sent_gauss(self)  # batch_l x max_doc_l x n_topic, sum over latent dimension
        self.losses_kl_sent_gmm = tf.reduce_sum(tf.multiply(self.probs_sent_topic_posterior, self.losses_kl_sent_gauss), -1)  # batch_l x max_doc_l, sum over topics
        self.loss_kl_sent_gmm = tf.reduce_mean(tf.dynamic_partition(self.losses_kl_sent_gmm, partition_doc, num_partitions=2)[1])  # average over doc x batch

        # for monitor
        self.losses_kl_topic_gauss = compute_kl_losses_topic_gauss(self)  # batch_l x 1 x n_topic, sum over latent dimension
        self.losses_kl_topic_gmm = tf.reduce_sum(tf.multiply(self.probs_sent_topic_posterior, self.losses_kl_topic_gauss), -1)  # batch_l x max_doc_l, sum over topics
        self.loss_kl_topic_gmm = tf.reduce_mean(tf.dynamic_partition(self.losses_kl_topic_gmm, partition_doc, num_partitions=2)[1])

        # ------------------------------Discriminator Loss------------------------------
        if self.config.disc_topic:
            self.losses_disc_topic = -tf.reduce_sum(self.logits_summary_topic_posterior, -1)  # batch_l, sum over topic
            self.loss_disc_topic = tf.reduce_sum(self.losses_disc_topic) / self.n_sents  # average over doc x batch
        else:
            self.loss_disc_topic = tf.constant(0, dtype=tf.float32)

        # ------------------------------Optimizer------------------------------
        if self.config.anneal == 'linear':
            self.tau = tf.cast(tf.divide(self.global_step, tf.constant(self.config.linear_steps)), dtype=tf.float32)
            self.beta = tf.minimum(1., self.config.beta_init + self.tau)
        elif self.config.anneal == 'cycle':
            self.tau = tf.cast(tf.divide(tf.mod(self.global_step, tf.constant(self.config.cycle_steps)),
                                         tf.constant(self.config.cycle_steps)), dtype=tf.float32)
            self.beta = tf.minimum(1., self.config.beta_init + self.tau / (1. - self.config.r_cycle))
        else:
            self.beta = tf.constant(1.)

        self.beta_disc = self.beta if self.config.beta_disc else tf.constant(1.)

        def get_opt(loss, var_list, lr, global_step=None):
            if self.config.opt == 'adam':
                Optimizer = tf.train.AdamOptimizer
            elif self.config.opt == 'adagrad':
                Optimizer = tf.train.AdagradOptimizer
            optimizer = Optimizer(lr)
            grad_vars = optimizer.compute_gradients(loss=loss, var_list=var_list)
            clipped_grad_vars = [(tf.clip_by_value(grad, -self.config.grad_clip, self.config.grad_clip), var)
                                 for grad, var in grad_vars if grad is not None]
            opt = optimizer.apply_gradients(clipped_grad_vars, global_step=global_step)
            return opt, grad_vars, clipped_grad_vars

        # ------------------------------Loss Setting------------------------------
        if self.config.turn:
            self.loss = self.loss_recon + \
                self.beta * tf.maximum(tf.maximum(self.loss_kl_sent_gmm, self.config.capacity_gmm)
                                       - self.loss_kl_topic_gmm_reverse, self.config.margin_gmm) + \
                self.beta * self.loss_kl_root + \
                self.topic_loss_recon + \
                self.beta * self.topic_loss_kl_bow + \
                self.beta * self.topic_loss_kl_prob + \
                self.config.lam_reg * self.loss_reg
            self.opt, self.grad_vars, self.clipped_grad_vars = \
                get_opt(self.loss,
                        var_list=list(tf.trainable_variables('emb') + tf.trainable_variables('enc') + tf.trainable_variables('dec')),
                        lr=self.config.lr, global_step=self.global_step)

            self.loss_disc = self.beta_disc * self.config.lam_disc * self.loss_disc_topic + \
                self.beta * tf.maximum(self.loss_kl_prob, self.config.capacity_prob)
            self.opt_disc, self.grad_vars_disc, self.clipped_grad_vars_disc = \
                get_opt(self.loss_disc,
                        var_list=list(tf.trainable_variables('emb') + tf.trainable_variables('disc')),
                        lr=self.config.lr_disc)
        else:
            self.loss = self.loss_recon + \
                self.beta * tf.maximum(tf.maximum(self.loss_kl_sent_gmm, self.config.capacity_gmm)
                                       - self.loss_kl_topic_gmm_reverse, self.config.margin_gmm) + \
                self.beta * self.loss_kl_root + \
                self.topic_loss_recon + \
                self.beta * self.topic_loss_kl_bow + \
                self.beta * self.topic_loss_kl_prob + \
                self.beta_disc * self.config.lam_disc * self.loss_disc_topic + \
                self.beta * tf.maximum(self.loss_kl_prob, self.config.capacity_prob) + \
                self.config.lam_reg * self.loss_reg
            self.loss_disc = tf.constant(0, dtype=tf.float32)

            self.opt, self.grad_vars, self.clipped_grad_vars = \
                get_opt(self.loss, var_list=tf.trainable_variables(), lr=self.config.lr, global_step=self.global_step)
            self.opt_infer, self.grad_vars_infer, self.clipped_grad_vars_infer = \
                get_opt(self.loss,
                        var_list=list(tf.trainable_variables('emb') + tf.trainable_variables('enc') + tf.trainable_variables('disc')),
                        lr=self.config.lr, global_step=self.global_step)
            self.opt_gen, self.grad_vars_gen, self.clipped_grad_vars_gen = \
                get_opt(self.loss, var_list=list(tf.trainable_variables('dec')), lr=self.config.lr, global_step=self.global_step)
            self.opt_disc = tf.constant(0, dtype=tf.float32)

        # for monitoring logdetcov
        self.logdetcovs_topic_posterior = tf_log(tf.linalg.det(self.covs_topic_posterior))
        self.mask_depth = tf.tile(tf.expand_dims(tf.constant([self.config.tree_depth[topic_idx] - 1 for topic_idx in self.config.topic_idxs],
                                                             dtype=tf.int32), 0), [self.batch_l, 1])
        self.depth_logdetcovs_topic_posterior = tf.dynamic_partition(self.logdetcovs_topic_posterior, self.mask_depth,
                                                                     num_partitions=max(self.config.tree_depth.values()))  # list<depth>: n_topic_depth * batch_l

        # ------------------------------Evaluation------------------------------
        self.loss_list_train = [self.loss, self.loss_disc, self.loss_recon, self.loss_kl_prob, self.loss_kl_sent_gmm,
                                self.loss_disc_topic, tf.constant(0)]
        self.loss_list_eval = [self.loss, self.loss_disc, self.loss_recon, self.loss_kl_prob, self.loss_kl_sent_gmm,
                               self.loss_disc_topic, self.loss_kl_topic_gmm]
        self.loss_sum = (self.loss_recon + self.loss_kl_prob + self.loss_kl_sent_gmm) * self.n_sents
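# NOTE: `sample_gumbels` is defined elsewhere in the repository. The sketch below illustrates
# one plausible implementation, a Gumbel-softmax relaxation with the same argument order
# (logits, temperature, seed, sample flag); the repository version may differ, e.g. by using
# a straight-through estimator instead of returning soft probabilities.
import tensorflow as tf

def sample_gumbels(logits, softmax_temperature, seed=None, sample=True):
    # Perturb the logits with Gumbel(0, 1) noise and apply a temperature-scaled softmax,
    # yielding differentiable "soft" one-hot token indices; without sampling, only soften the logits.
    def perturbed():
        uniform = tf.random_uniform(tf.shape(logits), minval=1e-10, maxval=1., seed=seed)
        return logits + (-tf.log(-tf.log(uniform)))
    noisy_logits = tf.cond(tf.cast(sample, tf.bool), perturbed, lambda: logits)
    return tf.nn.softmax(noisy_logits / softmax_temperature, axis=-1)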
def build(self):
    self.global_step = tf.Variable(0, name='global_step', trainable=False)
    self.softmax_temperature = tf.maximum(
        self.config.max_temperature - tf.cast(tf.divide(self.global_step, tf.constant(self.config.linear_steps)), dtype=tf.float32),
        self.config.min_temperature)

    with tf.name_scope('t_variables'):
        self.sample = self.t_variables['sample']

        self.batch_l = self.t_variables['batch_l']
        self.doc_l = self.t_variables['doc_l']
        self.sent_l = self.t_variables['sent_l']
        self.dec_sent_l = self.t_variables['dec_sent_l']  # batch_l x max_doc_l
        self.n_oov = self.t_variables['n_oov']

        self.max_doc_l = tf.reduce_max(self.doc_l)
        self.max_sent_l = tf.reduce_max(self.sent_l)
        self.max_dec_sent_l = tf.reduce_max(self.dec_sent_l)  # = max_sent_l + 1

        self.mask_doc = tf.sequence_mask(self.doc_l, dtype=tf.float32)
        self.mask_sent = tf.sequence_mask(self.sent_l, dtype=tf.float32)

        self.enc_keep_prob = self.t_variables['enc_keep_prob']

    # ------------------------------Encoder------------------------------
    with tf.variable_scope('enc'):
        with tf.variable_scope('word', reuse=False):
            self.enc_embeddings = get_embeddings(self)

            # get sentence embeddings
            self.enc_inputs = tf.nn.embedding_lookup(self.enc_embeddings, self.t_variables['enc_input_idxs'])  # batch_l x max_doc_l x max_sent_l x dim_emb

        with tf.variable_scope('sent', reuse=False):
            self.sent_outputs, self.sent_state = \
                encode_inputs(self, enc_inputs=self.enc_inputs, sent_l=self.sent_l)  # batch_l x max_doc_l x dim_hidden*2
            self.memory_idxs = self.t_variables['enc_input_idxs']

        # get sentence latents
        with tf.variable_scope('latents_sent', reuse=False):
            self.latents_sent_posterior, self.means_sent_posterior, self.logvars_sent_posterior = \
                encode_latents_gauss(self.sent_state, dim_latent=self.config.dim_latent, sample=self.sample,
                                     config=self.config, name='sent_posterior', min_logvar=self.config.min_logvar)  # batch_l x max_doc_l x dim_latent

    # ------------------------------Discriminator------------------------------
    with tf.variable_scope('disc'):
        with tf.variable_scope('prob_topic', reuse=False):
            if self.config.latent:
                self.latents_probs_sent_topic_posterior, self.means_probs_sent_topic_posterior, self.logvars_probs_sent_topic_posterior = \
                    encode_latents_gauss(self.sent_state, dim_latent=self.config.dim_latent, sample=self.sample,
                                         config=self.config, name='probs_topic_posterior')  # batch_l x max_doc_l x dim_latent
                self.probs_sent_topic_posterior, self.tree_sent_sticks_path, self.tree_sent_sticks_depth = \
                    encode_nhdp_probs_topic_posterior(self, self.latents_probs_sent_topic_posterior.get_shape()[-1],
                                                      self.latents_probs_sent_topic_posterior, self.mask_doc, self.config)  # batch_l x max_doc_l x n_topic
            else:
                self.probs_sent_topic_posterior, self.tree_sent_sticks_path, self.tree_sent_sticks_depth = \
                    encode_nhdp_probs_topic_posterior(self, self.sent_state.get_shape()[-1], self.sent_state,
                                                      self.mask_doc, self.config)  # batch_l x max_doc_l x n_topic

        with tf.name_scope('latents_topic'):
            # get topic sentence posterior distribution for each document
            self.probs_topic_posterior = tf.reduce_sum(self.probs_sent_topic_posterior, 1)  # batch_l x n_topic

            self.means_sent_topic_posterior = tf.multiply(tf.expand_dims(self.probs_sent_topic_posterior, -1),
                                                          tf.expand_dims(self.means_sent_posterior, -2))  # batch_l x max_doc_l x n_topic x dim_latent
            self.means_topic_posterior_ = tf.reduce_sum(self.means_sent_topic_posterior, 1) / \
                tf.expand_dims(self.probs_topic_posterior, -1)  # batch_l x n_topic x dim_latent
            self.means_topic_posterior = tf_clip_means(self.means_topic_posterior_, self.probs_topic_posterior)

            diffs_sent_topic_posterior = tf.expand_dims(self.means_sent_posterior, 2) - \
                tf.expand_dims(self.means_topic_posterior, 1)  # batch_l x max_doc_l x n_topic x dim_latent
            self.covs_sent_topic_posterior = tf.multiply(
                tf.expand_dims(tf.expand_dims(self.probs_sent_topic_posterior, -1), -1),
                tf.matrix_diag(tf.expand_dims(tf.exp(self.logvars_sent_posterior), 2)) +
                tf.matmul(tf.expand_dims(diffs_sent_topic_posterior, -1),
                          tf.expand_dims(diffs_sent_topic_posterior, -2)))  # batch_l x max_doc_l x n_topic x dim_latent x dim_latent
            self.covs_topic_posterior_ = tf.reduce_sum(self.covs_sent_topic_posterior, 1) / \
                tf.expand_dims(tf.expand_dims(self.probs_topic_posterior, -1), -1)  # batch_l x n_topic x dim_latent x dim_latent
            self.covs_topic_posterior = tf_clip_covs(self.covs_topic_posterior_, self.probs_topic_posterior)

            self.latents_topic_posterior = sample_latents_fullcov(self.means_topic_posterior, self.covs_topic_posterior,
                                                                  seed=self.config.seed, sample=self.sample)

            self.mean_root_prior = tf.zeros([self.batch_l, self.config.dim_latent], dtype=tf.float32)
            self.cov_root_prior = tf.eye(self.config.dim_latent, batch_shape=[self.batch_l], dtype=tf.float32) * self.config.cov_root

            self.means_topic_prior = get_params_topic_prior(self, self.means_topic_posterior, self.mean_root_prior)  # batch_l x n_topic x dim_latent
            self.covs_topic_prior = get_params_topic_prior(self, self.covs_topic_posterior, self.cov_root_prior)  # batch_l x n_topic x dim_latent x dim_latent

    # ------------------------------Decoder------------------------------
    with tf.variable_scope('dec'):
        with tf.variable_scope('word', reuse=False):
            self.dec_embeddings = get_embeddings(self)

        # decode for training sent
        with tf.variable_scope('sent', initializer=tf.contrib.layers.xavier_initializer(), dtype=tf.float32, reuse=False):
            self.dec_cell = tf.contrib.rnn.GRUCell(self.config.dim_hidden)
            self.dec_cell = tf.contrib.rnn.DropoutWrapper(self.dec_cell, output_keep_prob=self.t_variables['dec_keep_prob'])

            self.latent_hidden_layer = tf.layers.Dense(units=self.config.dim_hidden, activation=tf.nn.relu, name='latent_hidden_linear')
            self.dec_sent_initial_state = self.latent_hidden_layer(self.latents_sent_posterior)  # batch_l x max_doc_l x dim_hidden
            self.output_layer = tf.layers.Dense(self.config.n_vocab, use_bias=False, dtype=tf.float32, name='output')

            if self.config.pointer:
                def pointer_layer(cell_outputs, cell_state, memory_idxs):
                    output_probs = tf.nn.softmax(cell_outputs, -1)  # batch_l*n_tile x n_vocab
                    if self.config.oov:
                        if output_probs.shape.rank == 3:
                            output_probs = tf.concat([output_probs,
                                                      tf.zeros([tf.shape(output_probs)[0], tf.shape(output_probs)[1], self.n_oov],
                                                               dtype=output_probs.dtype)], -1)
                        elif output_probs.shape.rank == 2:
                            output_probs = tf.concat([output_probs,
                                                      tf.zeros([tf.shape(output_probs)[0], self.n_oov],
                                                               dtype=output_probs.dtype)], -1)

                    prob_gen = cell_state.prob_gen  # batch_l*n_tile x 1
                    alignments = cell_state.alignments  # batch_l*n_tile x memory_l
                    if output_probs.shape.rank == 3:
                        alignments = tf.reshape(alignments, [tf.shape(alignments)[0] * tf.shape(alignments)[1], tf.shape(alignments)[-1]])

                    memory_indices = tf.concat([tf.expand_dims(tf.tile(tf.expand_dims(tf.range(tf.shape(memory_idxs)[0]), -1),
                                                                       [1, tf.shape(memory_idxs)[1]]), -1),
                                                tf.expand_dims(memory_idxs, -1)], -1)
                    attn_probs = tf.tensor_scatter_nd_add(tf.zeros([tf.shape(alignments)[0], self.config.n_vocab + self.n_oov], dtype=tf.float32),
                                                          indices=memory_indices, updates=alignments)  # batch_l*n_tile x n_vocab
                    if output_probs.shape.rank == 3:
                        attn_probs = tf.reshape(attn_probs, [tf.shape(output_probs)[0], tf.shape(output_probs)[1], tf.shape(output_probs)[-1]])

                    pointer_probs = prob_gen * output_probs + (1. - prob_gen) * attn_probs
                    pointer_logits = tf_log(pointer_probs)
                    return pointer_logits

                self.prob_gen_layer = tf.layers.Dense(1, use_bias=True, activation=tf.nn.sigmoid, dtype=tf.float32, name='pointer')
                self.pointer_layer = pointer_layer
            else:
                self.prob_gen_layer = None
                self.pointer_layer = None

            if self.config.attention:
                self.dec_sent_cell = wrap_attention(self, self.dec_cell, self.sent_outputs, n_tiled=self.max_doc_l)
            else:
                self.dec_sent_cell = self.dec_cell

            # teacher forcing
            self.dec_input_idxs = self.t_variables['dec_input_idxs']  # batch_l x max_doc_l x max_dec_sent_l
            self.dec_inputs = tf.nn.embedding_lookup(self.dec_embeddings, self.dec_input_idxs)  # batch_l x max_doc_l x max_dec_sent_l x dim_emb

            # output_sent_l == dec_sent_l
            self.dec_sent_outputs, self.dec_sent_final_state, self.output_sent_l_flat = decode_output_logits_flat(
                self, dec_cell=self.dec_sent_cell, dec_initial_state=self.dec_sent_initial_state,
                dec_inputs=self.dec_inputs, dec_sent_l=self.dec_sent_l,
                latents_input=self.latents_sent_posterior)  # batch_l*max_doc_l x max_output_sent_l x n_vocab

            self.output_logits_flat = self.dec_sent_outputs.rnn_output
            self.output_sent_l = tf.reshape(self.output_sent_l_flat, [self.batch_l, self.max_doc_l])
            self.max_output_sent_l = tf.reduce_max(self.output_sent_l)
            self.output_logits = tf.reshape(self.output_logits_flat,
                                            [self.batch_l, self.max_doc_l, self.max_output_sent_l, tf.shape(self.output_logits_flat)[-1]],
                                            name='output_logits')

            if self.config.disc_gumbel:
                self.output_input_idxs = sample_gumbels(self.output_logits, self.softmax_temperature,
                                                        self.config.seed, self.sample)  # batch_l x max_doc_l x max_output_sent_l x n_vocab
            else:
                self.output_input_idxs = self.output_logits

            # re-encode original sentence outputs
            self.output_inputs = tf.tensordot(self.output_input_idxs, self.enc_embeddings, axes=[[-1], [0]])  # batch_l x max_doc_l x max_sent_l x dim_emb
            self.output_input_sent_l = self.output_sent_l - 1  # to remove EOS
            self.mask_output_doc = self.mask_doc

        # decode for training topic probs
        with tf.variable_scope('sent', initializer=tf.contrib.layers.xavier_initializer(), dtype=tf.float32, reuse=True):
            if self.config.attention:
                self.dec_topic_cell = wrap_attention(self, self.dec_cell, self.sent_outputs, n_tiled=self.config.n_topic)
            else:
                self.dec_topic_cell = self.dec_cell

            if not self.config.disc_mean:
                self.dec_topic_initial_state = self.latent_hidden_layer(self.latents_topic_posterior)
                self.dec_topic_outputs, self.dec_topic_final_state, self.summary_sent_l_flat = decode_output_sample_flat(
                    self, dec_cell=self.dec_topic_cell, dec_initial_state=self.dec_topic_initial_state,
                    softmax_temperature=self.softmax_temperature, sample=self.sample,
                    latents_input=self.latents_topic_posterior)  # batch_l*n_topic x max_summary_sent_l x n_vocab
            else:
                self.dec_topic_initial_state = self.latent_hidden_layer(self.means_topic_posterior)
                self.dec_topic_outputs, self.dec_topic_final_state, self.summary_sent_l_flat = decode_output_sample_flat(
                    self, dec_cell=self.dec_topic_cell, dec_initial_state=self.dec_topic_initial_state,
                    softmax_temperature=self.softmax_temperature, sample=self.sample,
                    latents_input=self.means_topic_posterior)  # batch_l*n_topic x max_summary_sent_l x n_vocab

            self.summary_sent_l = tf.reshape(self.summary_sent_l_flat, [self.batch_l, self.config.n_topic])
            self.max_summary_sent_l = tf.reduce_max(self.summary_sent_l)

            if self.config.disc_gumbel:
                summary_input_idxs_flat = self.dec_topic_outputs.sample_id
            else:
                summary_input_idxs_flat = self.dec_topic_outputs.rnn_output
            self.summary_input_idxs = tf.reshape(summary_input_idxs_flat,
                                                 [self.batch_l, self.config.n_topic, self.max_summary_sent_l, tf.shape(summary_input_idxs_flat)[-1]],
                                                 name='summary_input_idxs')

            # re-encode topic sentence outputs
            self.summary_inputs = tf.tensordot(self.summary_input_idxs, self.enc_embeddings, axes=[[-1], [0]])  # batch_l x n_topic x max_summary_sent_l x dim_emb
            self.summary_input_sent_l = self.summary_sent_l - 1  # to remove EOS
            self.mask_summary_doc = tf.ones([self.batch_l, self.config.n_topic], dtype=tf.float32)

        # beam decode for inference of original sentences
        with tf.variable_scope('sent', initializer=tf.contrib.layers.xavier_initializer(), dtype=tf.float32, reuse=True):
            if self.config.nucleus < 1:
                self.beam_output_idxs = decode_sample_output_token_idxs(self, dec_cell=self.dec_sent_cell,
                                                                        dec_initial_state=self.dec_sent_initial_state,
                                                                        latents_input=self.means_sent_posterior,
                                                                        name='beam_output_idxs')
            else:
                if self.config.attention:
                    self.beam_dec_sent_cell = wrap_attention(self, self.dec_cell, self.sent_outputs,
                                                             n_tiled=self.max_doc_l, beam_width=self.config.beam_width)
                else:
                    self.beam_dec_sent_cell = self.dec_cell

                # infer original sentences
                self.beam_output_idxs = decode_beam_output_token_idxs(self, beam_dec_cell=self.beam_dec_sent_cell,
                                                                      dec_initial_state=self.dec_sent_initial_state,
                                                                      latents_input=self.means_sent_posterior,
                                                                      name='beam_output_idxs')

        # beam decode for inference of topic sentences
        with tf.variable_scope('sent', initializer=tf.contrib.layers.xavier_initializer(), dtype=tf.float32, reuse=True):
            if self.config.nucleus < 1:
                self.beam_summary_idxs = decode_sample_output_token_idxs(self, dec_cell=self.dec_topic_cell,
                                                                         dec_initial_state=self.dec_topic_initial_state,
                                                                         latents_input=self.means_topic_posterior,
                                                                         name='beam_summary_idxs')
            else:
                if self.config.attention:
                    self.beam_dec_topic_cell = wrap_attention(self, self.dec_cell, self.sent_outputs,
                                                              n_tiled=self.config.n_topic, beam_width=self.config.beam_width)
                else:
                    self.beam_dec_topic_cell = self.dec_cell

                # infer topic sentences
                self.beam_summary_idxs = decode_beam_output_token_idxs(self, beam_dec_cell=self.beam_dec_topic_cell,
                                                                       dec_initial_state=self.dec_topic_initial_state,
                                                                       latents_input=self.means_topic_posterior,
                                                                       name='beam_summary_idxs')

    # ------------------------------Discriminator------------------------------
    with tf.variable_scope('enc'):
        with tf.variable_scope('sent', reuse=True):
            _, self.output_state = encode_inputs(self, enc_inputs=self.output_inputs,
                                                 sent_l=self.output_input_sent_l)  # batch_l x max_doc_l x dim_hidden*2
        with tf.variable_scope('sent', reuse=True):
            _, self.summary_state = encode_inputs(self, enc_inputs=self.summary_inputs,
                                                  sent_l=self.summary_input_sent_l)  # batch_l x max_doc_l x dim_hidden*2

    with tf.variable_scope('disc'):
        with tf.variable_scope('prob_topic', reuse=True):
            self.probs_output_topic_posterior, _, _ = \
                encode_nhdp_probs_topic_posterior(self, self.output_state.get_shape()[-1], self.output_state,
                                                  self.mask_output_doc, self.config)
            self.logits_output_topic_posterior = tf_log(self.probs_output_topic_posterior)  # batch_l x max_doc_l x n_topic

        with tf.variable_scope('prob_topic', reuse=True):
            if self.config.latent:
                self.latents_probs_summary_topic_posterior, self.means_probs_summary_topic_posterior, self.logvars_probs_summary_topic_posterior = \
                    encode_latents_gauss(self.summary_state, dim_latent=self.config.dim_latent, sample=self.sample,
                                         config=self.config, name='probs_topic_posterior')  # batch_l x max_doc_l x dim_latent
                self.probs_summary_topic_posterior, _, _ = \
                    encode_nhdp_probs_topic_posterior(self, self.latents_probs_summary_topic_posterior.get_shape()[-1],
                                                      self.latents_probs_summary_topic_posterior, self.mask_summary_doc, self.config)  # batch_l x max_doc_l x n_topic
            else:
                self.probs_summary_topic_posterior, _, _ = \
                    encode_nhdp_probs_topic_posterior(self, self.summary_state.get_shape()[-1], self.summary_state,
                                                      self.mask_summary_doc, self.config)
            self.logits_summary_topic_posterior = tf_log(tf.matrix_diag_part(self.probs_summary_topic_posterior))

    # ------------------------------Optimizer and Loss------------------------------
    with tf.name_scope('opt'):
        partition_doc = tf.cast(self.mask_doc, dtype=tf.int32)
        self.n_sents = tf.cast(tf.reduce_sum(self.doc_l), dtype=tf.float32)
        self.n_tokens = tf.reduce_sum(self.dec_sent_l)

        # ------------------------------Reconstruction Loss of Language Model------------------------------
        # target and mask
        self.dec_target_idxs = self.t_variables['dec_target_idxs']  # batch_l x max_doc_l x max_dec_sent_l
        self.dec_sent_l = self.t_variables['dec_sent_l']  # batch_l x max_doc_l
        self.max_dec_sent_l = tf.reduce_max(self.dec_sent_l)  # = max_sent_l + 1
        self.dec_mask_sent = tf.sequence_mask(self.dec_sent_l, maxlen=self.max_dec_sent_l, dtype=tf.float32)
        self.dec_target_idxs_flat = tf.reshape(self.dec_target_idxs, [self.batch_l * self.max_doc_l, self.max_dec_sent_l])
        self.dec_mask_sent_flat = tf.reshape(self.dec_mask_sent, [self.batch_l * self.max_doc_l, self.max_dec_sent_l])

        # nll for each token (summed over sentence)
        self.recon_max_sent_l = tf.minimum(self.max_dec_sent_l, self.max_output_sent_l) if self.config.sample else None
        losses_recon_flat = tf.reduce_sum(
            tf.contrib.seq2seq.sequence_loss(self.output_logits_flat[:, :self.recon_max_sent_l, :],
                                             self.dec_target_idxs_flat[:, :self.recon_max_sent_l],
                                             self.dec_mask_sent_flat[:, :self.recon_max_sent_l],
                                             average_across_timesteps=False,
                                             average_across_batch=False), -1)  # batch_l*max_doc_l
        self.losses_recon = tf.reshape(losses_recon_flat, [self.batch_l, self.max_doc_l])
        self.loss_recon = tf.reduce_mean(tf.dynamic_partition(self.losses_recon, partition_doc, num_partitions=2)[1])  # average over doc x batch

        # ------------------------------KL divergence Loss of Topic Probability Distribution------------------------------
        if self.config.latent:
            self.losses_kl_prob = compute_kl_losses(self.means_probs_sent_topic_posterior, self.logvars_probs_sent_topic_posterior)
        else:
            if self.config.prior:
                self.probs_sent_topic_prior = tf.tile(tf.expand_dims(tf.expand_dims(tf.constant(self.config.probs_topic_prior, dtype=tf.float32), 0), 0),
                                                      [self.batch_l, self.max_doc_l, 1])
            else:
                self.probs_sent_topic_prior = tf.ones_like(self.probs_sent_topic_posterior, dtype=tf.float32) / \
                    self.config.n_topic  # batch_l x max_doc_l x n_topic, uniform distribution over topics
            self.losses_kl_prob = tf.reduce_sum(tf.multiply(self.probs_sent_topic_posterior,
                                                            tf_log(self.probs_sent_topic_posterior) - tf_log(self.probs_sent_topic_prior)), -1)
        self.loss_kl_prob = tf.reduce_mean(tf.dynamic_partition(self.losses_kl_prob, partition_doc, num_partitions=2)[1])  # average over doc x batch

        # ------------------------------KL divergence Loss of Sentence Latents Distribution------------------------------
        self.losses_kl_sent_gauss = compute_kl_losses_sent_gauss(self)  # batch_l x max_doc_l x n_topic, sum over latent dimension
        self.losses_kl_sent_gmm = tf.reduce_sum(tf.multiply(self.probs_sent_topic_posterior, self.losses_kl_sent_gauss), -1)  # batch_l x max_doc_l, sum over topics
        self.loss_kl_sent_gmm = tf.reduce_mean(tf.dynamic_partition(self.losses_kl_sent_gmm, partition_doc, num_partitions=2)[1])  # average over doc x batch

        # for monitor
        self.losses_kl_topic_gauss = compute_kl_losses_topic_gauss(self)  # batch_l x 1 x n_topic, sum over latent dimension
        self.losses_kl_topic_gmm = tf.reduce_sum(tf.multiply(self.probs_sent_topic_posterior, self.losses_kl_topic_gauss), -1)  # batch_l x max_doc_l, sum over topics
        self.loss_kl_topic_gmm = tf.reduce_mean(tf.dynamic_partition(self.losses_kl_topic_gmm, partition_doc, num_partitions=2)[1])

        # ------------------------------Discriminator Loss of Original sentences------------------------------
        if self.config.disc_sent:
            # self.losses_disc_sent = -tf.reduce_sum(tf.multiply(self.probs_sent_topic_posterior, self.logits_output_topic_posterior), -1)
            self.losses_disc_sent = tf.nn.softmax_cross_entropy_with_logits_v2(labels=self.probs_sent_topic_posterior,
                                                                               logits=self.logits_output_topic_posterior,
                                                                               dim=-1)  # batch_l x max_doc_l, sum over topic
            self.loss_disc_sent = tf.reduce_mean(tf.dynamic_partition(self.losses_disc_sent, partition_doc, num_partitions=2)[1])
        else:
            self.loss_disc_sent = tf.constant(0, dtype=tf.float32)

        # ------------------------------Discriminator Loss of Topic sentences------------------------------
        if self.config.disc_topic:
            self.losses_disc_topic = -tf_clip_vals(self.logits_summary_topic_posterior, self.probs_topic_posterior)  # batch_l x n_topic
            if self.config.disc_weight:
                self.loss_disc_topic = tf.reduce_mean(self.losses_disc_topic)  # average over topic x batch
            else:
                self.loss_disc_topic = tf.reduce_sum(self.losses_disc_topic) / self.n_sents  # average over doc x batch
        else:
            self.loss_disc_topic = tf.constant(0, dtype=tf.float32)

        # ------------------------------Coverage Loss------------------------------
        if self.config.regularizer:
            self.losses_coverage_sent = compute_losses_coverage(self, self.dec_sent_final_state, self.dec_mask_sent,
                                                                n_tiled=self.max_doc_l)  # batch_l x max_doc_l
            self.loss_coverage_sent = tf.reduce_mean(tf.dynamic_partition(self.losses_coverage_sent, partition_doc, num_partitions=2)[1])  # average over doc x batch

            self.dec_mask_summary = tf.sequence_mask(self.summary_sent_l, maxlen=self.max_summary_sent_l, dtype=tf.float32)
            self.losses_coverage_topic = compute_losses_coverage(self, self.dec_topic_final_state, self.dec_mask_summary,
                                                                 n_tiled=self.config.n_topic)  # batch_l x n_topic
            if self.config.disc_weight:
                self.loss_coverage_topic = tf.reduce_mean(self.losses_coverage_topic)  # average over topic x batch
            else:
                self.loss_coverage_topic = tf.reduce_sum(self.losses_coverage_topic) / self.n_sents  # average over doc x batch
            self.loss_coverage = self.loss_coverage_sent + self.loss_coverage_topic
        else:
            self.loss_coverage = tf.constant(0, dtype=tf.float32)

        # ------------------------------Optimizer------------------------------
        if self.config.anneal == 'linear':
            self.tau = tf.cast(tf.divide(self.global_step, tf.constant(self.config.linear_steps)), dtype=tf.float32)
            self.beta = tf.minimum(self.config.beta_last, self.config.beta_init + self.tau)
        else:
            self.beta = tf.constant(1.)

        self.beta_disc = self.beta if self.config.beta_disc else tf.constant(1.)

        if self.config.lr_step:
            self.lr_step = tf.minimum(tf.cast(self.global_step, tf.float32) ** (-1 / 2),
                                      tf.cast(self.global_step, tf.float32) * self.config.warmup ** (-3 / 2))
        else:
            self.lr_step = 1.
        self.lr = self.config.lr * self.lr_step
        self.lr_disc = self.config.lr_disc * self.lr_step

        def get_opt(loss, var_list, lr, global_step=None):
            if self.config.opt == 'adam':
                Optimizer = tf.train.AdamOptimizer
            elif self.config.opt == 'adagrad':
                Optimizer = tf.train.AdagradOptimizer
            optimizer = Optimizer(lr)
            grad_vars = optimizer.compute_gradients(loss=loss, var_list=var_list)
            clipped_grad_vars = [(tf.clip_by_value(grad, -self.config.grad_clip, self.config.grad_clip), var)
                                 for grad, var in grad_vars if grad is not None]
            opt = optimizer.apply_gradients(clipped_grad_vars, global_step=global_step)
            return opt, grad_vars, clipped_grad_vars

        # ------------------------------Loss Setting------------------------------
        if self.config.turn:
            self.loss = self.loss_recon + \
                self.beta * tf.maximum(self.loss_kl_sent_gmm, self.config.capacity_gmm) + \
                self.config.lam_reg * self.loss_coverage
            self.opt, self.grad_vars, self.clipped_grad_vars = \
                get_opt(self.loss, var_list=list(tf.trainable_variables('enc') + tf.trainable_variables('dec')),
                        lr=self.lr, global_step=self.global_step)
            self.opt_infer, self.grad_vars_infer, self.clipped_grad_vars_infer = \
                get_opt(self.loss, var_list=list(tf.trainable_variables('enc')), lr=self.lr, global_step=self.global_step)
            self.opt_gen, self.grad_vars_gen, self.clipped_grad_vars_gen = \
                get_opt(self.loss, var_list=list(tf.trainable_variables('dec')), lr=self.lr, global_step=self.global_step)

            self.loss_disc = self.beta_disc * self.config.lam_disc * self.loss_disc_sent + \
                self.beta_disc * self.config.lam_disc * self.loss_disc_topic + \
                self.beta * tf.maximum(self.loss_kl_prob, self.config.capacity_prob)
            self.opt_disc, self.grad_vars_disc, self.clipped_grad_vars_disc = \
                get_opt(self.loss_disc, var_list=list(tf.trainable_variables('enc') + tf.trainable_variables('disc')), lr=self.lr_disc)
            self.opt_disc_infer, self.grad_vars_disc_infer, self.clipped_grad_vars_disc_infer = \
                get_opt(self.loss_disc, var_list=list(tf.trainable_variables('enc') + tf.trainable_variables('disc')), lr=self.lr_disc)
            self.opt_disc_gen = tf.constant(0, dtype=tf.float32)
            self.opt_enc = tf.constant(0, dtype=tf.float32)
        elif self.config.control:
            self.loss = self.loss_recon + \
                self.beta * tf.maximum(self.loss_kl_sent_gmm, self.config.capacity_gmm) + \
                self.beta * tf.maximum(self.loss_kl_prob, self.config.capacity_prob) + \
                self.config.lam_reg * self.loss_coverage
            self.loss_disc = self.beta_disc * self.config.lam_disc * self.loss_disc_sent + \
                self.beta_disc * self.config.lam_disc * self.loss_disc_topic
            self.opt, self.grad_vars, self.clipped_grad_vars = \
                get_opt(self.loss + self.loss_disc, var_list=list(tf.trainable_variables('dec')), lr=self.lr, global_step=self.global_step)
            self.opt_enc, self.grad_vars_enc, self.clipped_grad_vars_enc = \
                get_opt(self.loss, var_list=list(tf.trainable_variables('enc')), lr=self.lr)
            self.opt_disc, self.grad_vars_disc, self.clipped_grad_vars_disc = \
                get_opt(self.loss_disc, var_list=list(tf.trainable_variables('disc')), lr=self.lr_disc)
        else:
            self.loss = self.loss_recon + \
                self.beta * tf.maximum(self.loss_kl_sent_gmm, self.config.capacity_gmm) + \
                self.beta_disc * self.config.lam_disc * self.loss_disc_sent + \
                self.beta_disc * self.config.lam_disc * self.loss_disc_topic + \
                self.beta * tf.maximum(self.loss_kl_prob, self.config.capacity_prob) + \
                self.config.lam_reg * self.loss_coverage
            self.opt, self.grad_vars, self.clipped_grad_vars = \
                get_opt(self.loss, var_list=tf.trainable_variables(), lr=self.lr, global_step=self.global_step)
            self.opt_infer, self.grad_vars_infer, self.clipped_grad_vars_infer = \
                get_opt(self.loss, var_list=list(tf.trainable_variables('enc') + tf.trainable_variables('disc')),
                        lr=self.lr, global_step=self.global_step)
            self.opt_gen, self.grad_vars_gen, self.clipped_grad_vars_gen = \
                get_opt(self.loss, var_list=list(tf.trainable_variables('dec')), lr=self.lr, global_step=self.global_step)

            self.loss_disc = tf.constant(0, dtype=tf.float32)
            self.opt_enc = tf.constant(0, dtype=tf.float32)
            self.opt_disc = tf.constant(0, dtype=tf.float32)
            self.opt_disc_infer = tf.constant(0, dtype=tf.float32)
            self.opt_disc_gen = tf.constant(0, dtype=tf.float32)

        # for monitoring logdetcov
        self.logdetcovs_topic_posterior = tf_log(tf.linalg.det(self.covs_topic_posterior))
        self.mask_depth = tf.tile(tf.expand_dims(tf.constant([self.config.tree_depth[topic_idx] - 1 for topic_idx in self.config.topic_idxs],
                                                             dtype=tf.int32), 0), [self.batch_l, 1])
        self.depth_logdetcovs_topic_posterior = tf.dynamic_partition(self.logdetcovs_topic_posterior, self.mask_depth,
                                                                     num_partitions=max(self.config.tree_depth.values()))  # list<depth>: n_topic_depth * batch_l

        # ------------------------------Evaluation------------------------------
        self.loss_list_train = [self.loss, self.loss_disc, self.loss_recon, self.loss_kl_prob, self.loss_kl_sent_gmm,
                                self.loss_disc_sent, self.loss_disc_topic, self.loss_coverage]
        self.loss_list_eval = [self.loss, self.loss_disc, self.loss_recon, self.loss_kl_prob, self.loss_kl_sent_gmm,
                               self.loss_disc_sent, self.loss_disc_topic, self.loss_kl_topic_gmm]
        self.loss_sum = (self.loss_recon + self.loss_kl_prob + self.loss_kl_sent_gmm) * self.n_sents
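# NOTE: the `lr_step` factor above is the Transformer-style warmup schedule
# min(step^-0.5, step * warmup^-1.5): it rises linearly until `global_step` reaches
# `config.warmup`, then decays with the inverse square root of the step.
# The helper below is hypothetical (not part of the repository); it only restates that
# formula in plain Python so the schedule can be inspected offline.
def noam_lr_factor(step, warmup):
    step = max(step, 1)  # avoid division by zero at step 0
    return min(step ** -0.5, step * warmup ** -1.5)

# Example: with warmup=4000 the factor peaks at step 4000 (4000 ** -0.5 ~= 0.0158).
factors = [noam_lr_factor(s, warmup=4000) for s in (1, 1000, 4000, 16000)]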