Example #1
    def build(self):
        def get_topic_loss_reg(topic_embeddings):
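            # Regularizer pushing topic embeddings toward mutual orthogonality: the
            # cosine-similarity matrix of the L2-normalized topic embeddings is penalized
            # for deviating from the identity matrix.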
            topic_embeddings_norm = topic_embeddings / tf.norm(topic_embeddings, axis=1, keepdims=True)
            self.topic_dots = tf.clip_by_value(tf.matmul(topic_embeddings_norm, tf.transpose(topic_embeddings_norm)), -1., 1.)
            topic_loss_reg = tf.reduce_mean(tf.square(self.topic_dots - tf.eye(self.config.n_topic)))
            return topic_loss_reg
           
        # -------------- Build Model --------------
        tf.reset_default_graph()
        
        tf.set_random_seed(self.config.seed)
        
        self.t_variables['bow'] = tf.placeholder(tf.float32, [None, self.config.dim_bow])
        self.t_variables['keep_prob'] = tf.placeholder(tf.float32)
        
        # encode bow
        with tf.variable_scope('topic/enc', reuse=False):
            hidden_bow_ = tf.layers.Dense(units=self.config.dim_hidden_bow, activation=tf.nn.tanh, name='hidden_bow')(self.t_variables['bow'])
            hidden_bow = tf.layers.Dropout(self.t_variables['keep_prob'])(hidden_bow_)
            means_bow = tf.layers.Dense(units=self.config.dim_latent_bow, name='mean_bow')(hidden_bow)
            logvars_bow = tf.layers.Dense(units=self.config.dim_latent_bow, kernel_initializer=tf.constant_initializer(0), bias_initializer=tf.constant_initializer(0), name='logvar_bow')(hidden_bow)
            latents_bow = sample_latents(means_bow, logvars_bow) # sample latent vectors
            
            self.prob_topic = tf.layers.Dense(units=self.config.n_topic, activation=tf.nn.softmax, name='prob_topic')(latents_bow) # inference of topic probabilities

        # decode bow
        with tf.variable_scope('shared', reuse=False):
            self.bow_embeddings = tf.get_variable('emb', [self.config.dim_bow, self.config.dim_emb], dtype=tf.float32, initializer=tf.contrib.layers.xavier_initializer(), trainable=self.config.train_emb) # embeddings of vocab

        with tf.variable_scope('topic/dec', reuse=False):
            self.topic_embeddings = tf.get_variable('topic_emb', [self.config.n_topic, self.config.dim_emb], dtype=tf.float32, initializer=tf.contrib.layers.xavier_initializer()) # embeddings of topics

            self.topic_bow = tf.nn.softmax(tf.matmul(self.topic_embeddings, self.bow_embeddings, transpose_b=True), 1) # bow vectors for each topic
            self.logits_bow = tf_log(tf.matmul(self.prob_topic, self.topic_bow)) # predicted bow distribution: n_batch x n_vocab
            
        # define losses
        self.topic_losses_recon = -tf.reduce_sum(tf.multiply(self.t_variables['bow'], self.logits_bow), 1)
        self.topic_loss_recon = tf.reduce_mean(self.topic_losses_recon) # negative log likelihood of the words, averaged over the batch
        self.topic_losses_kl = compute_kl_loss(means_bow, logvars_bow) # KL divergence between the latent distribution and the standard Gaussian prior
        self.topic_loss_kl = tf.reduce_mean(self.topic_losses_kl, 0) # mean of the KL losses over the batch
        self.topic_loss_reg = get_topic_loss_reg(self.topic_embeddings)
        self.loss = self.topic_loss_recon + self.topic_loss_kl + self.config.reg * self.topic_loss_reg

        # define optimizer
        if self.config.opt == 'Adam':
            optimizer = tf.train.AdamOptimizer(self.config.lr)
        elif self.config.opt == 'Adagrad':
            optimizer = tf.train.AdagradOptimizer(self.config.lr)

        self.grad_vars = optimizer.compute_gradients(self.loss)
        self.clipped_grad_vars = [(tf.clip_by_value(grad, -self.config.grad_clip, self.config.grad_clip), var) for grad, var in self.grad_vars]
        self.global_step = tf.Variable(0, name='global_step', trainable=False)
        self.opt = optimizer.apply_gradients(self.clipped_grad_vars, global_step=self.global_step)

        # monitor
        self.n_bow = tf.reduce_sum(self.t_variables['bow'], 1)
        self.topic_ppls = tf.divide(self.topic_losses_recon + self.topic_losses_kl, tf.maximum(1e-5, self.n_bow))
    
        # growth criteria
        self.n_topics = tf.multiply(tf.expand_dims(self.n_bow, -1), self.prob_topic)
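
Example #1 calls several helpers that are defined elsewhere in the repository (sample_latents, compute_kl_loss, tf_log). The snippet below is a minimal sketch of what such helpers typically look like in a TF1 VAE-style topic model; the names and signatures are taken from the call sites above, but the bodies are standard formulations rather than the original implementations.

import tensorflow as tf

def sample_latents(means, logvars):
    # reparameterization trick: z = mu + sigma * eps, with eps ~ N(0, I)
    eps = tf.random_normal(tf.shape(means))
    return means + tf.exp(0.5 * logvars) * eps

def compute_kl_loss(means, logvars):
    # analytic KL divergence between N(mu, diag(exp(logvar))) and the standard normal
    # prior, summed over the latent dimension -> one value per example in the batch
    return -0.5 * tf.reduce_sum(1. + logvars - tf.square(means) - tf.exp(logvars), axis=-1)

def tf_log(x, eps=1e-10):
    # numerically stable log, used on softmax outputs such as self.topic_bow
    return tf.log(x + eps)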
Example #2
    def build(self):
        def get_prob_topic(tree_prob_leaf, prob_depth):
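            # Combine leaf-path and depth probabilities into per-node topic probabilities:
            # each node accumulates, over the leaves beneath it, the leaf's path probability
            # times the probability of stopping at that node's depth.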
            tree_prob_topic = defaultdict(float)
            leaf_ancestor_idxs = {leaf_idx: get_ancestor_idxs(leaf_idx, self.child_to_parent_idxs) for leaf_idx in tree_prob_leaf}
            for leaf_idx, ancestor_idxs in leaf_ancestor_idxs.items():
                prob_leaf = tree_prob_leaf[leaf_idx]
                for i, ancestor_idx in enumerate(ancestor_idxs):
                    prob_ancestor = prob_leaf * tf.expand_dims(prob_depth[:, i], -1)
                    tree_prob_topic[ancestor_idx] += prob_ancestor
            prob_topic = tf.concat([tree_prob_topic[topic_idx] for topic_idx in self.topic_idxs], -1)
            return prob_topic     
        
        def get_tree_topic_bow(tree_topic_embeddings):
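            # Per-topic word distributions: a temperature-scaled softmax over the vocabulary
            # of the dot products between each topic embedding and the word embeddings, with
            # the temperature depending on the topic's depth in the tree.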
            tree_topic_bow = {}
            for topic_idx, depth in self.tree_depth.items():
                topic_embedding = tree_topic_embeddings[topic_idx]
                temperature = tf.constant(self.config.depth_temperature ** (1./depth), dtype=tf.float32)
                logits = tf.matmul(topic_embedding, self.bow_embeddings, transpose_b=True)
                tree_topic_bow[topic_idx] = softmax_with_temperature(logits, axis=-1, temperature=temperature)
                
            return tree_topic_bow
        
        def get_topic_loss_reg(tree_topic_embeddings):
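            # Tree regularizer: the L2-normalized differences between each child embedding and
            # its parent embedding are pushed toward mutual orthogonality, with tree_mask_reg
            # restricting the penalty to sibling pairs.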
            def get_tree_mask_reg(all_child_idxs):        
                tree_mask_reg = np.zeros([len(all_child_idxs), len(all_child_idxs)], dtype=np.float32)
                for parent_idx, child_idxs in self.tree_idxs.items():
                    neighbor_idxs = child_idxs
                    for neighbor_idx1 in neighbor_idxs:
                        for neighbor_idx2 in neighbor_idxs:
                            neighbor_index1 = all_child_idxs.index(neighbor_idx1)
                            neighbor_index2 = all_child_idxs.index(neighbor_idx2)
                            tree_mask_reg[neighbor_index1, neighbor_index2] = tree_mask_reg[neighbor_index2, neighbor_index1] = 1.
                return tree_mask_reg
            
            all_child_idxs = list(self.child_to_parent_idxs.keys())
            self.diff_topic_embeddings = tf.concat([tree_topic_embeddings[child_idx] - tree_topic_embeddings[self.child_to_parent_idxs[child_idx]] for child_idx in all_child_idxs], axis=0)
            diff_topic_embeddings_norm = self.diff_topic_embeddings / tf.norm(self.diff_topic_embeddings, axis=1, keepdims=True)
            self.topic_dots = tf.clip_by_value(tf.matmul(diff_topic_embeddings_norm, tf.transpose(diff_topic_embeddings_norm)), -1., 1.)        

            self.tree_mask_reg = get_tree_mask_reg(all_child_idxs)
            self.topic_losses_reg = tf.square(self.topic_dots - tf.eye(len(all_child_idxs))) * self.tree_mask_reg
            self.topic_loss_reg = tf.reduce_sum(self.topic_losses_reg) / tf.reduce_sum(self.tree_mask_reg)
            return self.topic_loss_reg
           
        # -------------- Build Model --------------
        tf.reset_default_graph()
        
        tf.set_random_seed(self.config.seed)
        
        self.t_variables['bow'] = tf.placeholder(tf.float32, [None, self.config.dim_bow])
        self.t_variables['keep_prob'] = tf.placeholder(tf.float32)
        
        # encode bow
        with tf.variable_scope('topic/enc', reuse=False):
            hidden_bow_ = tf.layers.Dense(units=self.config.dim_hidden_bow, activation=tf.nn.tanh, name='hidden_bow')(self.t_variables['bow'])
            hidden_bow = tf.layers.Dropout(self.t_variables['keep_prob'])(hidden_bow_)
            means_bow = tf.layers.Dense(units=self.config.dim_latent_bow, name='mean_bow')(hidden_bow)
            logvars_bow = tf.layers.Dense(units=self.config.dim_latent_bow, kernel_initializer=tf.constant_initializer(0), bias_initializer=tf.constant_initializer(0), name='logvar_bow')(hidden_bow)
            latents_bow = sample_latents(means_bow, logvars_bow) # sample latent vectors
            prob_layer = lambda h: tf.nn.sigmoid(tf.matmul(latents_bow, h, transpose_b=True))

            tree_sticks_topic, tree_states_sticks_topic = doubly_rnn(self.config.dim_latent_bow, self.tree_idxs, output_layer=prob_layer, name='sticks_topic')
            self.tree_prob_leaf = tsbp(tree_sticks_topic, self.tree_idxs)
            
            sticks_depth, _ = rnn(self.config.dim_latent_bow, self.n_depth, output_layer=prob_layer, name='prob_depth')
            self.prob_depth = sbp(sticks_depth, self.n_depth)

            self.prob_topic = get_prob_topic(self.tree_prob_leaf, self.prob_depth) # n_batch x n_topic

        # decode bow
        with tf.variable_scope('shared', reuse=False):
            self.bow_embeddings = tf.get_variable('emb', [self.config.dim_bow, self.config.dim_emb], dtype=tf.float32, initializer=tf.contrib.layers.xavier_initializer()) # embeddings of vocab

        with tf.variable_scope('topic/dec', reuse=False):
            emb_layer = lambda h: tf.layers.Dense(units=self.config.dim_emb, name='output')(tf.nn.tanh(h))
            self.tree_topic_embeddings, tree_states_topic_embeddings = doubly_rnn(self.config.dim_emb, self.tree_idxs, output_layer=emb_layer, name='emb_topic')

            self.tree_topic_bow = get_tree_topic_bow(self.tree_topic_embeddings) # bow vectors for each topic

            self.topic_bow = tf.concat([self.tree_topic_bow[topic_idx] for topic_idx in self.topic_idxs], 0) # KxV
            self.logits_bow = tf_log(tf.matmul(self.prob_topic, self.topic_bow)) # predicted bow distribution: n_batch x n_vocab
            
        # define losses
        self.topic_losses_recon = -tf.reduce_sum(tf.multiply(self.t_variables['bow'], self.logits_bow), 1)
        self.topic_loss_recon = tf.reduce_mean(self.topic_losses_recon) # negative log likelihood of the words, averaged over the batch

        self.topic_losses_kl = compute_kl_losses(means_bow, logvars_bow) # KL divergence between the latent distribution and the standard Gaussian prior
        self.topic_loss_kl = tf.reduce_mean(self.topic_losses_kl, 0) # mean of the KL losses over the batch
        
        self.topic_embeddings = tf.concat([self.tree_topic_embeddings[topic_idx] for topic_idx in self.topic_idxs], 0) # temporary
        self.topic_loss_reg = get_topic_loss_reg(self.tree_topic_embeddings)

        self.global_step = tf.Variable(0, name='global_step', trainable=False)

        self.loss = self.topic_loss_recon + self.topic_loss_kl + self.config.reg * self.topic_loss_reg

        # define optimizer
        if self.config.opt == 'Adam':
            optimizer = tf.train.AdamOptimizer(self.config.lr)
        elif self.config.opt == 'Adagrad':
            optimizer = tf.train.AdagradOptimizer(self.config.lr)

        self.grad_vars = optimizer.compute_gradients(self.loss)
        self.clipped_grad_vars = [(tf.clip_by_value(grad, -self.config.grad_clip, self.config.grad_clip), var) for grad, var in self.grad_vars]
        self.opt = optimizer.apply_gradients(self.clipped_grad_vars, global_step=self.global_step)

        # monitor
        self.n_bow = tf.reduce_sum(self.t_variables['bow'], 1)
        self.topic_ppls = tf.divide(self.topic_losses_recon + self.topic_losses_kl, tf.maximum(1e-5, self.n_bow))
    
        # growth criteria
        self.n_topics = tf.multiply(tf.expand_dims(self.n_bow, -1), self.prob_topic)
        
        self.arcs_bow = tf.acos(tf.matmul(tf.linalg.l2_normalize(self.bow_embeddings, axis=-1), tf.linalg.l2_normalize(self.topic_embeddings, axis=-1), transpose_b=True)) # n_vocab x n_topic
        self.rads_bow = tf.multiply(tf.matmul(self.t_variables['bow'], self.arcs_bow), self.prob_topic) # n_batch x n_topic
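
Example #2 additionally relies on stick-breaking helpers (tsbp over the tree, sbp over the depth distribution) and a temperature-scaled softmax, none of which are shown here. Below is a minimal sketch of the depth-level stick-breaking and the tempered softmax, under the assumption that the sticks are per-level sigmoid outputs in (0, 1); the repository's actual implementations may differ.

import tensorflow as tf

def sbp(sticks_depth, n_depth):
    # stick-breaking over depth: pi_k = v_k * prod_{j<k} (1 - v_j), with the last level
    # absorbing the remaining stick so that the depth probabilities sum to 1.
    # sticks_depth is assumed to be a list of n_depth tensors of shape (n_batch, 1).
    probs = []
    remaining = tf.ones_like(sticks_depth[0])
    for k in range(n_depth):
        if k == n_depth - 1:
            probs.append(remaining)
        else:
            probs.append(remaining * sticks_depth[k])
            remaining = remaining * (1. - sticks_depth[k])
    return tf.concat(probs, axis=-1)  # n_batch x n_depth

def softmax_with_temperature(logits, axis, temperature):
    # temperature < 1 sharpens the distribution, temperature > 1 flattens it
    return tf.nn.softmax(logits / temperature, axis=axis)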
Example #3
    def build(self):
        self.global_step = tf.Variable(0, name='global_step', trainable=False)
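        # the softmax temperature is annealed linearly from max_temperature down to
        # min_temperature as training progresses (it decreases by 1 every linear_steps steps)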
        self.softmax_temperature = tf.maximum( \
                                              self.config.max_temperature-tf.cast(tf.divide(self.global_step, tf.constant(self.config.linear_steps)), dtype=tf.float32), \
                                              self.config.min_temperature)

        with tf.name_scope('t_variables'):
            self.sample = self.t_variables['sample']

            self.batch_l = self.t_variables['batch_l']
            self.doc_l = self.t_variables['doc_l']
            self.sent_l = self.t_variables['sent_l']
            self.dec_sent_l = self.t_variables[
                'dec_sent_l']  # batch_l x max_doc_l

            self.max_doc_l = tf.reduce_max(self.doc_l)
            self.max_sent_l = tf.reduce_max(self.sent_l)
            self.max_dec_sent_l = tf.reduce_max(
                self.dec_sent_l)  # = max_sent_l + 1

            self.mask_doc = tf.sequence_mask(self.doc_l, dtype=tf.float32)
            self.mask_sent = tf.sequence_mask(self.sent_l, dtype=tf.float32)

            mask_bow = np.zeros(self.config.n_vocab)
            mask_bow[self.config.bow_idxs] = 1.
            self.mask_bow = tf.constant(mask_bow, dtype=tf.float32)

            self.enc_keep_prob = self.t_variables['enc_keep_prob']

        # ------------------------------Encoder ------------------------------
        with tf.variable_scope('emb'):
            with tf.variable_scope('word', reuse=False):
                pad_embedding = tf.zeros([1, self.config.dim_emb],
                                         dtype=tf.float32)
                nonpad_embeddings = tf.get_variable('emb', [self.config.n_vocab-1, self.config.dim_emb], dtype=tf.float32, \
                                                                initializer=tf.contrib.layers.xavier_initializer())
                self.embeddings = tf.concat([pad_embedding, nonpad_embeddings],
                                            0)  # n_vocab x dim_emb
                self.bow_embeddings = tf.nn.embedding_lookup(
                    self.embeddings, self.config.bow_idxs)  # dim_bow x dim_emb

                # get sentence embeddings
                self.enc_input_idxs = tf.one_hot(
                    self.t_variables['enc_input_idxs'],
                    depth=self.config.n_vocab
                )  # batch_l x max_doc_l x max_sent_l x n_vocab
                self.enc_inputs = tf.tensordot(
                    self.enc_input_idxs, self.embeddings,
                    axes=[[-1],
                          [0]])  # batch_l x max_doc_l x max_sent_l x dim_emb

            with tf.variable_scope('sent', reuse=False):
                self.sent_outputs, self.sent_state = \
                    encode_inputs(self, enc_inputs=self.enc_inputs, sent_l=self.sent_l) # batch_l x max_doc_l x dim_hidden*2

            # get sentence latents
            with tf.variable_scope('latents_sent', reuse=False):
                self.latents_sent_posterior, self.means_sent_posterior, self.logvars_sent_posterior = \
                        encode_latents_gauss(self.sent_state, dim_latent=self.config.dim_latent, sample=self.sample, \
                                             config=self.config, name='sent_posterior', min_logvar=self.config.min_logvar) # batch_l x max_doc_l x dim_latent

        # ------------------------------Discriminator------------------------------
        with tf.variable_scope('disc'):
            with tf.variable_scope('prob_topic', reuse=False):
                # encode by TSNTM
                self.probs_sent_topic_posterior, self.tree_sent_sticks_path, self.tree_sent_sticks_depth = \
                    encode_nhdp_probs_topic_posterior(self, self.sent_state.get_shape()[-1], self.sent_state, self.mask_doc, self.config) # batch_l x max_doc_l x n_topic

            with tf.name_scope('latents_topic'):
                # get topic sentence posterior distribution for each document
                self.probs_topic_posterior = tf.reduce_sum(
                    self.probs_sent_topic_posterior, 1)  # batch_l x n_topic
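                # The quantities below are responsibility-weighted averages of the sentence-level
                # posteriors: topic means are weighted means of the sentence means, and topic
                # covariances combine the per-sentence diagonal covariances with the spread of
                # the sentence means around the topic mean (as in the M-step of a Gaussian mixture).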

                self.means_sent_topic_posterior = tf.multiply(tf.expand_dims(self.probs_sent_topic_posterior, -1), \
                        tf.expand_dims(self.means_sent_posterior, -2)) # batch_l x max_doc_l x n_topic x dim_latent
                self.means_topic_posterior_ = tf.reduce_sum(self.means_sent_topic_posterior, 1) / \
                        tf.expand_dims(self.probs_topic_posterior, -1) # batch_l x n_topic x dim_latent
                self.means_topic_posterior = tf_clip_means(
                    self.means_topic_posterior_, self.probs_topic_posterior)

                diffs_sent_topic_posterior = tf.expand_dims(self.means_sent_posterior, 2) - \
                        tf.expand_dims(self.means_topic_posterior, 1) # batch_l x max_doc_l x n_topic x dim_latent
                self.covs_sent_topic_posterior = tf.multiply(tf.expand_dims(tf.expand_dims(self.probs_sent_topic_posterior, -1), -1), \
                        tf.matrix_diag(tf.expand_dims(tf.exp(self.logvars_sent_posterior), 2)) + tf.matmul(tf.expand_dims(diffs_sent_topic_posterior, -1), \
                        tf.expand_dims(diffs_sent_topic_posterior, -2))) # batch_l x max_doc_l x n_topic x dim_latent x dim_latent
                self.covs_topic_posterior_ = tf.reduce_sum(self.covs_sent_topic_posterior, 1) / \
                        tf.expand_dims(tf.expand_dims(self.probs_topic_posterior, -1), -1) # batch_l x n_topic x dim_latent x dim_latent
                self.covs_topic_posterior = tf_clip_covs(
                    self.covs_topic_posterior_, self.probs_topic_posterior)

                self.latents_topic_posterior = sample_latents_fullcov(self.means_topic_posterior, self.covs_topic_posterior, \
                                                                      seed=self.config.seed, sample=self.sample)

                if not self.config.prior_root:
                    self.mean_root_prior = tf.zeros(
                        [self.batch_l, self.config.dim_latent],
                        dtype=tf.float32)
                    self.cov_root_prior = tf.eye(
                        self.config.dim_latent,
                        batch_shape=[self.batch_l],
                        dtype=tf.float32) * self.config.cov_root

                self.means_topic_prior = get_params_topic_prior(
                    self, self.means_topic_posterior,
                    self.mean_root_prior)  # batch_l x n_topic x dim_latent
                self.covs_topic_prior = get_params_topic_prior(
                    self, self.covs_topic_posterior, self.cov_root_prior
                )  # batch_l x n_topic x dim_latent x dim_latent

        # ------------------------------Decoder----------------------------------
        with tf.variable_scope('dec'):
            # decode for training sent
            with tf.variable_scope(
                    'sent',
                    initializer=tf.contrib.layers.xavier_initializer(),
                    dtype=tf.float32,
                    reuse=False):
                self.dec_cell = tf.contrib.rnn.GRUCell(self.config.dim_hidden)
                self.dec_cell = tf.contrib.rnn.DropoutWrapper(
                    self.dec_cell,
                    output_keep_prob=self.t_variables['dec_keep_prob'])
                self.dec_sent_cell = self.dec_cell
                self.latent_hidden_layer = tf.layers.Dense(
                    units=self.config.dim_hidden,
                    activation=tf.nn.relu,
                    name='latent_hidden_linear')
                self.dec_sent_initial_state = self.latent_hidden_layer(
                    self.latents_sent_posterior
                )  # batch_l x max_doc_l x dim_hidden
                self.output_layer = tf.layers.Dense(self.config.n_vocab,
                                                    use_bias=False,
                                                    name='out')

                if self.config.attention:
                    self.sent_outputs_flat = tf.reshape(
                        self.sent_outputs, [
                            self.batch_l * self.max_doc_l, self.max_sent_l,
                            self.config.dim_hidden * 2
                        ])
                    self.att_sent_l_flat = tf.reshape(
                        tf.maximum(self.sent_l, 1),
                        [self.batch_l * self.max_doc_l])
                    self.att_sent_mechanism = tf.contrib.seq2seq.LuongAttention(num_units=self.config.dim_hidden,
                                                                                memory=self.sent_outputs_flat, \
                                                                                memory_sequence_length=self.att_sent_l_flat)
                    self.att_cell = tf.contrib.seq2seq.AttentionWrapper(
                        self.dec_cell,
                        attention_mechanism=self.att_sent_mechanism,
                        attention_layer_size=self.config.dim_hidden)
                    self.dec_sent_cell = self.att_cell

                # teacher forcing
                self.dec_input_idxs = self.t_variables[
                    'dec_input_idxs']  # batch_l x max_doc_l x max_dec_sent_l
                self.dec_inputs = tf.nn.embedding_lookup(
                    self.embeddings, self.dec_input_idxs
                )  # batch_l x max_doc_l x max_dec_sent_l x dim_emb

                # output_sent_l == dec_sent_l
                self.output_logits_flat, self.output_sent_l_flat = decode_output_logits_flat(
                    self,
                    dec_cell=self.dec_sent_cell,
                    dec_initial_state=self.dec_sent_initial_state,
                    dec_inputs=self.dec_inputs,
                    dec_sent_l=self.dec_sent_l,
                    latents_input=self.latents_sent_posterior
                )  # batch_l*max_doc_l x max_output_sent_l x n_vocab

                self.output_sent_l = tf.reshape(self.output_sent_l_flat,
                                                [self.batch_l, self.max_doc_l])
                self.max_output_sent_l = tf.reduce_max(self.output_sent_l)
                self.output_logits = tf.reshape(self.output_logits_flat, \
                                                [self.batch_l, self.max_doc_l, self.max_output_sent_l, self.config.n_vocab], name='output_logits')
                if self.config.disc_gumbel:
                    self.output_input_idxs = sample_gumbels(
                        self.output_logits, self.softmax_temperature,
                        self.config.seed, self.sample
                    )  # batch_l x max_doc_l x max_output_sent_l  x n_vocab
                else:
                    self.output_input_idxs = self.output_logits

            # decode for training topic probs
            with tf.variable_scope(
                    'sent',
                    initializer=tf.contrib.layers.xavier_initializer(),
                    dtype=tf.float32,
                    reuse=True):
                self.dec_topic_cell = self.dec_cell
                if self.config.attention:
                    self.topic_outputs_flat = tf.contrib.seq2seq.tile_batch(tf.reshape(self.sent_outputs, \
                                            [self.batch_l, self.max_doc_l*self.max_sent_l, self.sent_outputs.get_shape()[-1]]), \
                                            multiplier=self.config.n_topic) # batch_l*n_topic x max_doc_l*max_sent_l x dim_hidden*2
                    self.score_mask = tf.contrib.seq2seq.tile_batch(tf.reshape(tf.sequence_mask(self.sent_l), \
                                            [self.batch_l, self.max_doc_l*self.max_sent_l]), multiplier=self.config.n_topic) # batch_l*n_topic x max_doc_l*max_sent_l
                    self.hier_score = tf.reshape(tf.transpose(self.probs_sent_topic_posterior, [0, 2, 1]), \
                                            [self.batch_l*self.config.n_topic, self.max_doc_l]) # batch_l*n_topic x max_doc_l

                    self.att_topic_mechanism = HierarchicalAttention(
                        num_units=self.config.dim_hidden,
                        memory=self.topic_outputs_flat,
                        score_mask=self.score_mask,
                        hier_score=self.hier_score)
                    self.att_topic_cell = AttentionWrapper(
                        self.dec_cell,
                        attention_mechanism=self.att_topic_mechanism,
                        attention_layer_size=self.config.dim_hidden)
                    self.dec_topic_cell = self.att_topic_cell

                if not self.config.disc_mean:
                    self.dec_topic_initial_state = self.latent_hidden_layer(
                        self.latents_topic_posterior)
                    dec_topic_outputs, self.summary_sent_l_flat = decode_output_sample_flat(
                        self,
                        dec_cell=self.dec_topic_cell,
                        dec_initial_state=self.dec_topic_initial_state,
                        softmax_temperature=self.softmax_temperature,
                        sample=self.sample,
                        latents_input=self.latents_topic_posterior
                    )  # batch_l*max_doc_l x max_summary_sent_l x n_vocab
                else:
                    self.dec_topic_initial_state = self.latent_hidden_layer(
                        self.means_topic_posterior)
                    dec_topic_outputs, self.summary_sent_l_flat = decode_output_sample_flat(
                        self,
                        dec_cell=self.dec_topic_cell,
                        dec_initial_state=self.dec_topic_initial_state,
                        softmax_temperature=self.softmax_temperature,
                        sample=self.sample,
                        latents_input=self.means_topic_posterior
                    )  # batch_l*max_doc_l x max_summary_sent_l x n_vocab

                self.summary_sent_l = tf.reshape(
                    self.summary_sent_l_flat,
                    [self.batch_l, self.config.n_topic])
                self.max_summary_sent_l = tf.reduce_max(self.summary_sent_l)
                if self.config.disc_gumbel:
                    summary_input_idxs_flat = dec_topic_outputs.sample_id
                else:
                    summary_input_idxs_flat = dec_topic_outputs.rnn_output
                self.summary_input_idxs = tf.reshape(summary_input_idxs_flat, \
                                                     [self.batch_l, self.config.n_topic, self.max_summary_sent_l, self.config.n_vocab], name='summary_input_idxs')

                # re-encode topic sentence outputs
                self.summary_inputs = tf.tensordot(
                    self.summary_input_idxs, self.embeddings, axes=[[-1], [
                        0
                    ]])  # batch_l x n_topic x max_summary_sent_l x dim_emb
                self.summary_input_sent_l = self.summary_sent_l - 1  # to remove EOS
                self.mask_summary_sent = tf.sequence_mask(self.summary_input_sent_l, \
                                                          maxlen=self.max_summary_sent_l, dtype=tf.float32) # batch_l x n_topic x max_summary_sent_l
                self.mask_summary_doc = tf.ones(
                    [self.batch_l, self.config.n_topic], dtype=tf.float32)

            # beam decode for inference of original sentences
            with tf.variable_scope(
                    'sent',
                    initializer=tf.contrib.layers.xavier_initializer(),
                    dtype=tf.float32,
                    reuse=True):
                self.beam_dec_sent_cell = self.dec_cell
                if self.config.attention:
                    self.beam_sent_outputs_flat = tf.contrib.seq2seq.tile_batch(
                        self.sent_outputs_flat,
                        multiplier=self.config.beam_width)
                    self.beam_att_sent_l_flat = tf.contrib.seq2seq.tile_batch(
                        self.att_sent_l_flat,
                        multiplier=self.config.beam_width)
                    self.beam_att_sent_mechanism = tf.contrib.seq2seq.LuongAttention(
                        num_units=self.config.dim_hidden,
                        memory=self.beam_sent_outputs_flat,
                        memory_sequence_length=self.beam_att_sent_l_flat)
                    self.beam_dec_sent_cell = tf.contrib.seq2seq.AttentionWrapper(
                        self.beam_dec_sent_cell,
                        attention_mechanism=self.beam_att_sent_mechanism,
                        attention_layer_size=self.config.dim_hidden)

                # infer original sentences
                self.beam_output_idxs, _, _ = decode_beam_output_token_idxs(
                    self,
                    beam_dec_cell=self.beam_dec_sent_cell,
                    dec_initial_state=self.dec_sent_initial_state,
                    latents_input=self.means_sent_posterior,
                    name='beam_output_idxs')

            # beam decode for inference of topic sentences
            with tf.variable_scope(
                    'sent',
                    initializer=tf.contrib.layers.xavier_initializer(),
                    dtype=tf.float32,
                    reuse=True):
                self.beam_dec_topic_cell = self.dec_cell
                if self.config.attention:
                    self.beam_topic_outputs_flat = tf.contrib.seq2seq.tile_batch(
                        self.topic_outputs_flat,
                        multiplier=self.config.beam_width)
                    self.beam_score_mask = tf.contrib.seq2seq.tile_batch(
                        self.score_mask, multiplier=self.config.beam_width)
                    self.beam_hier_score = tf.contrib.seq2seq.tile_batch(
                        self.hier_score, multiplier=self.config.beam_width)
                    self.beam_att_topic_mechanism = HierarchicalAttention(
                        num_units=self.config.dim_hidden,
                        memory=self.beam_topic_outputs_flat,
                        score_mask=self.beam_score_mask,
                        hier_score=self.beam_hier_score)
                    self.beam_dec_topic_cell = AttentionWrapper(
                        self.beam_dec_topic_cell,
                        attention_mechanism=self.beam_att_topic_mechanism,
                        attention_layer_size=self.config.dim_hidden)

                # infer topic sentences
                self.beam_summary_idxs, _, _ = decode_beam_output_token_idxs(
                    self,
                    beam_dec_cell=self.beam_dec_topic_cell,
                    dec_initial_state=self.dec_topic_initial_state,
                    latents_input=self.latents_topic_posterior,
                    name='beam_summary_idxs')

                self.beam_mask_summary_sent = tf.logical_not(tf.equal(self.beam_summary_idxs, \
                                                                      self.config.EOS_IDX)) # batch_l x n_topic x max_summary_sent_l
                self.beam_summary_input_sent_l = tf.reduce_sum(
                    tf.cast(self.beam_mask_summary_sent, tf.int32),
                    -1)  # batch_l x n_topic
                beam_summary_soft_idxs = tf.one_hot(tf.where(self.beam_mask_summary_sent, \
                                                                            self.beam_summary_idxs, tf.zeros_like(self.beam_summary_idxs)), depth=self.config.n_vocab)
                self.beam_summary_inputs = tf.tensordot(beam_summary_soft_idxs, \
                                                        self.embeddings, [[-1], [0]]) # batch_l x n_topic x max_beam_summary_sent_l x dim_emb

        # ------------------------------Discriminator------------------------------
        with tf.variable_scope('emb'):
            with tf.variable_scope('sent', reuse=True):
                _, self.summary_state = encode_inputs(
                    self,
                    enc_inputs=self.summary_inputs,
                    sent_l=self.summary_input_sent_l
                )  # batch_l x max_doc_l x dim_hidden*2
                _, self.beam_summary_state = encode_inputs(
                    self,
                    enc_inputs=self.beam_summary_inputs,
                    sent_l=self.beam_summary_input_sent_l
                )  # batch_l x max_doc_l x dim_hidden*2

        with tf.variable_scope('disc'):
            with tf.variable_scope('prob_topic', reuse=True):
                self.probs_summary_topic_posterior, _, _ = \
                        encode_nhdp_probs_topic_posterior(self, self.summary_state.get_shape()[-1], self.summary_state, self.mask_summary_doc, self.config)
                self.logits_summary_topic_posterior_ = tf_log(
                    tf.matrix_diag_part(self.probs_summary_topic_posterior)
                )  # batch_l x n_topic
                self.logits_summary_topic_posterior = tf_clip_vals(
                    self.logits_summary_topic_posterior_,
                    self.probs_topic_posterior)

        # ------------------------------Optimizer and Loss------------------------------
        with tf.name_scope('opt'):
            partition_doc = tf.cast(self.mask_doc, dtype=tf.int32)
            self.n_sents = tf.cast(tf.reduce_sum(self.doc_l), dtype=tf.float32)
            self.n_tokens = tf.reduce_sum(self.dec_sent_l)

            # ------------------------------Reconstruction Loss of Language Model------------------------------
            # target and mask
            self.dec_target_idxs = self.t_variables[
                'dec_target_idxs']  # batch_l x max_doc_l x max_dec_sent_l
            self.dec_sent_l = self.t_variables[
                'dec_sent_l']  # batch_l x max_doc_l
            self.max_dec_sent_l = tf.reduce_max(
                self.dec_sent_l)  # = max_sent_l + 1
            self.dec_mask_sent = tf.sequence_mask(self.dec_sent_l,
                                                  maxlen=self.max_dec_sent_l,
                                                  dtype=tf.float32)
            self.dec_target_idxs_flat = tf.reshape(
                self.dec_target_idxs,
                [self.batch_l * self.max_doc_l, self.max_dec_sent_l])
            self.dec_mask_sent_flat = tf.reshape(
                self.dec_mask_sent,
                [self.batch_l * self.max_doc_l, self.max_dec_sent_l])

            # nll for each token (summed over sentence)
            self.recon_max_sent_l = tf.minimum(
                self.max_dec_sent_l,
                self.max_output_sent_l) if self.config.sample else None
            losses_recon_flat = tf.reduce_sum(
                tf.contrib.seq2seq.sequence_loss(
                    self.output_logits_flat[:, :self.recon_max_sent_l, :],
                    self.dec_target_idxs_flat[:, :self.recon_max_sent_l],
                    self.dec_mask_sent_flat[:, :self.recon_max_sent_l],
                    average_across_timesteps=False,
                    average_across_batch=False), -1)  # batch_l*max_doc_l
            self.losses_recon = tf.reshape(losses_recon_flat,
                                           [self.batch_l, self.max_doc_l])
            self.loss_recon = tf.reduce_mean(
                tf.dynamic_partition(
                    self.losses_recon, partition_doc,
                    num_partitions=2)[1])  # average over doc x batch

            # ------------------------------KL divergence Loss of Topic Probability Distribution------------------------------
            if self.config.topic_model:
                self.probs_sent_topic_prior = tf.expand_dims(
                    self.probs_doc_topic_posterior, 1)  # batch_l x 1 x n_topic
            else:
                self.probs_sent_topic_prior = tf.ones_like(self.probs_sent_topic_posterior, dtype=tf.float32) / \
                                                        self.config.n_topic # batch_l x max_doc_l x n_topic, uniform distribution over topics
            self.losses_kl_prob = tf.reduce_sum(tf.multiply(self.probs_sent_topic_posterior, \
                                                            (tf_log(self.probs_sent_topic_posterior)-tf_log(self.probs_sent_topic_prior))), -1)
            self.loss_kl_prob = tf.reduce_mean(
                tf.dynamic_partition(
                    self.losses_kl_prob, partition_doc,
                    num_partitions=2)[1])  # average over doc x batch

            # ------------------------------KL divergence Loss of Sentence Latents Distribution------------------------------
            self.losses_kl_sent_gauss = compute_kl_losses_sent_gauss(
                self
            )  # batch_l x max_doc_l x n_topic, sum over latent dimension
            self.losses_kl_sent_gmm = tf.reduce_sum(
                tf.multiply(self.probs_sent_topic_posterior,
                            self.losses_kl_sent_gauss),
                -1)  # batch_l x max_doc_l, sum over topics
            self.loss_kl_sent_gmm = tf.reduce_mean(
                tf.dynamic_partition(
                    self.losses_kl_sent_gmm, partition_doc,
                    num_partitions=2)[1])  # average over doc x batch

            # for monitoring
            self.losses_kl_topic_gauss = compute_kl_losses_topic_gauss(
                self)  # batch_l x 1 x n_topic, sum over latent dimension
            self.losses_kl_topic_gmm = tf.reduce_sum(
                tf.multiply(self.probs_sent_topic_posterior,
                            self.losses_kl_topic_gauss),
                -1)  # batch_l x max_doc_l, sum over topics
            self.loss_kl_topic_gmm = tf.reduce_mean(
                tf.dynamic_partition(self.losses_kl_topic_gmm,
                                     partition_doc,
                                     num_partitions=2)[1])

            # ------------------------------Discriminator Loss------------------------------
            if self.config.disc_topic:
                self.losses_disc_topic = -tf.reduce_sum(
                    self.logits_summary_topic_posterior,
                    -1)  # batch_l, sum over topic
                self.loss_disc_topic = tf.reduce_sum(
                    self.losses_disc_topic
                ) / self.n_sents  # average over doc x batch
            else:
                self.loss_disc_topic = tf.constant(0, dtype=tf.float32)

            # ------------------------------Optimizer------------------------------
            if self.config.anneal == 'linear':
                self.tau = tf.cast(tf.divide(
                    self.global_step, tf.constant(self.config.linear_steps)),
                                   dtype=tf.float32)
                self.beta = tf.minimum(1., self.config.beta_init + self.tau)
            elif self.config.anneal == 'cycle':
                self.tau = tf.cast(tf.divide(
                    tf.mod(self.global_step,
                           tf.constant(self.config.cycle_steps)),
                    tf.constant(self.config.cycle_steps)),
                                   dtype=tf.float32)
                self.beta = tf.minimum(
                    1., self.config.beta_init + self.tau /
                    (1. - self.config.r_cycle))
            else:
                self.beta = tf.constant(1.)

            self.beta_disc = self.beta if self.config.beta_disc else tf.constant(
                1.)

            def get_opt(loss, var_list, lr, global_step=None):
                if self.config.opt == 'adam':
                    Optimizer = tf.train.AdamOptimizer
                elif self.config.opt == 'adagrad':
                    Optimizer = tf.train.AdagradOptimizer

                optimizer = Optimizer(lr)
                grad_vars = optimizer.compute_gradients(loss=loss,
                                                        var_list=var_list)
                clipped_grad_vars = [
                    (tf.clip_by_value(grad, -self.config.grad_clip,
                                      self.config.grad_clip), var)
                    for grad, var in grad_vars if grad is not None
                ]
                opt = optimizer.apply_gradients(clipped_grad_vars,
                                                global_step=global_step)
                return opt, grad_vars, clipped_grad_vars

            # ------------------------------Loss Setting------------------------------
            if self.config.turn:
                self.loss = self.loss_recon + \
                             self.beta * tf.maximum(tf.maximum(self.loss_kl_sent_gmm, self.config.capacity_gmm) \
                                                            - self.loss_kl_topic_gmm_reverse, self.config.margin_gmm) + \
                             self.beta * self.loss_kl_root + \
                             self.topic_loss_recon + \
                             self.beta * self.topic_loss_kl_bow + \
                             self.beta * self.topic_loss_kl_prob + \
                             self.config.lam_reg * self.loss_reg

                self.opt, self.grad_vars, self.clipped_grad_vars = \
                    get_opt(self.loss, var_list=list(tf.trainable_variables('emb') + tf.trainable_variables('enc') + tf.trainable_variables('dec')), \
                                lr=self.config.lr, global_step=self.global_step)

                self.loss_disc = self.beta_disc * self.config.lam_disc * self.loss_disc_topic + \
                                    self.beta * tf.maximum(self.loss_kl_prob, self.config.capacity_prob)

                self.opt_disc, self.grad_vars_disc, self.clipped_grad_vars_disc = \
                    get_opt(self.loss_disc, var_list=list(tf.trainable_variables('emb') + tf.trainable_variables('disc')), lr=self.config.lr_disc)

            else:
                self.loss = self.loss_recon + \
                             self.beta * tf.maximum(tf.maximum(self.loss_kl_sent_gmm, self.config.capacity_gmm) \
                                                            - self.loss_kl_topic_gmm_reverse, self.config.margin_gmm) + \
                             self.beta * self.loss_kl_root + \
                             self.topic_loss_recon + \
                             self.beta * self.topic_loss_kl_bow + \
                             self.beta * self.topic_loss_kl_prob + \
                             self.beta_disc * self.config.lam_disc * self.loss_disc_topic + \
                             self.beta * tf.maximum(self.loss_kl_prob, self.config.capacity_prob) + \
                             self.config.lam_reg * self.loss_reg
                self.loss_disc = tf.constant(0, dtype=tf.float32)

                self.opt, self.grad_vars, self.clipped_grad_vars = \
                    get_opt(self.loss, var_list=tf.trainable_variables(), lr=self.config.lr, global_step=self.global_step)
                self.opt_infer, self.grad_vars_infer, self.clipped_grad_vars_infer = \
                    get_opt(self.loss, var_list=list(tf.trainable_variables('emb')+tf.trainable_variables('enc')+tf.trainable_variables('disc')), lr=self.config.lr, global_step=self.global_step)
                self.opt_gen, self.grad_vars_gen, self.clipped_grad_vars_gen = \
                    get_opt(self.loss, var_list=list(tf.trainable_variables('dec')), lr=self.config.lr, global_step=self.global_step)

                self.opt_disc = tf.constant(0, dtype=tf.float32)

            # for monitoring logdetcov
            self.logdetcovs_topic_posterior = tf_log(
                tf.linalg.det(self.covs_topic_posterior))
            self.mask_depth = tf.tile(tf.expand_dims(tf.constant([self.config.tree_depth[topic_idx]-1 for topic_idx in self.config.topic_idxs], \
                                                                 dtype=tf.int32), 0), [self.batch_l, 1])
            self.depth_logdetcovs_topic_posterior = tf.dynamic_partition(self.logdetcovs_topic_posterior, self.mask_depth, \
                                                                     num_partitions=max(self.config.tree_depth.values())) # list<depth> :n_topic_depth*batch_l

            # ------------------------------Evaluation------------------------------
            self.loss_list_train = [self.loss, self.loss_disc, self.loss_recon, self.loss_kl_prob, self.loss_kl_sent_gmm, \
                                                    self.loss_disc_topic, tf.constant(0)]
            self.loss_list_eval = [self.loss, self.loss_disc, self.loss_recon, self.loss_kl_prob, self.loss_kl_sent_gmm,  \
                                                    self.loss_disc_topic, self.loss_kl_topic_gmm]
            self.loss_sum = (self.loss_recon + self.loss_kl_prob +
                             self.loss_kl_sent_gmm) * self.n_sents
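
Example #3 passes the decoder logits through sample_gumbels so that discrete token choices remain differentiable for the discriminator. Below is a minimal sketch of a Gumbel-softmax sampler with the signature used above; it treats sample as a boolean flag (the original feeds it through t_variables), and the repository's actual helper may differ.

import tensorflow as tf

def sample_gumbels(logits, softmax_temperature, seed=None, sample=True):
    # Gumbel-softmax relaxation: add Gumbel(0, 1) noise to the logits and apply a tempered
    # softmax, giving a differentiable approximation of sampled one-hot token choices.
    uniform = tf.random_uniform(tf.shape(logits), minval=1e-10, maxval=1., seed=seed)
    gumbel_noise = -tf.log(-tf.log(uniform))
    sampled = tf.nn.softmax((logits + gumbel_noise) / softmax_temperature, axis=-1)
    deterministic = tf.nn.softmax(logits / softmax_temperature, axis=-1)
    # `sample` may be a Python bool or a boolean tensor; tf.cond handles either case
    return tf.cond(tf.cast(sample, tf.bool), lambda: sampled, lambda: deterministic)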
Example #4
    def build(self):
        self.global_step = tf.Variable(0, name='global_step', trainable=False)
        self.softmax_temperature = tf.maximum( \
                                              self.config.max_temperature-tf.cast(tf.divide(self.global_step, tf.constant(self.config.linear_steps)), dtype=tf.float32), \
                                              self.config.min_temperature)

        with tf.name_scope('t_variables'):
            self.sample = self.t_variables['sample']

            self.batch_l = self.t_variables['batch_l']
            self.doc_l = self.t_variables['doc_l']
            self.sent_l = self.t_variables['sent_l']
            self.dec_sent_l = self.t_variables[
                'dec_sent_l']  # batch_l x max_doc_l
            self.n_oov = self.t_variables['n_oov']

            self.max_doc_l = tf.reduce_max(self.doc_l)
            self.max_sent_l = tf.reduce_max(self.sent_l)
            self.max_dec_sent_l = tf.reduce_max(
                self.dec_sent_l)  # = max_sent_l + 1

            self.mask_doc = tf.sequence_mask(self.doc_l, dtype=tf.float32)
            self.mask_sent = tf.sequence_mask(self.sent_l, dtype=tf.float32)

            self.enc_keep_prob = self.t_variables['enc_keep_prob']

        # ------------------------------Encoder ------------------------------
        with tf.variable_scope('enc'):
            with tf.variable_scope('word', reuse=False):
                self.enc_embeddings = get_embeddings(self)
                # get sentence embeddings
                self.enc_inputs = tf.nn.embedding_lookup(
                    self.enc_embeddings, self.t_variables['enc_input_idxs']
                )  # batch_l x max_doc_l x max_sent_l x dim_emb

            with tf.variable_scope('sent', reuse=False):
                self.sent_outputs, self.sent_state = \
                    encode_inputs(self, enc_inputs=self.enc_inputs, sent_l=self.sent_l) # batch_l x max_doc_l x dim_hidden*2
                self.memory_idxs = self.t_variables['enc_input_idxs']

            # get sentence latents
            with tf.variable_scope('latents_sent', reuse=False):
                self.latents_sent_posterior, self.means_sent_posterior, self.logvars_sent_posterior = \
                        encode_latents_gauss(self.sent_state, dim_latent=self.config.dim_latent, sample=self.sample, \
                                             config=self.config, name='sent_posterior', min_logvar=self.config.min_logvar) # batch_l x max_doc_l x dim_latent

        # ------------------------------Discriminator------------------------------
        with tf.variable_scope('disc'):
            with tf.variable_scope('prob_topic', reuse=False):
                if self.config.latent:
                    self.latents_probs_sent_topic_posterior, self.means_probs_sent_topic_posterior, self.logvars_probs_sent_topic_posterior = \
                            encode_latents_gauss(self.sent_state, dim_latent=self.config.dim_latent, sample=self.sample, \
                                             config=self.config, name='probs_topic_posterior') # batch_l x max_doc_l x dim_latent
                    self.probs_sent_topic_posterior, self.tree_sent_sticks_path, self.tree_sent_sticks_depth = \
                            encode_nhdp_probs_topic_posterior(self, self.latents_probs_sent_topic_posterior.get_shape()[-1], \
                                                              self.latents_probs_sent_topic_posterior, self.mask_doc, self.config) # batch_l x max_doc_l x n_topic
                else:
                    self.probs_sent_topic_posterior, self.tree_sent_sticks_path, self.tree_sent_sticks_depth = \
                        encode_nhdp_probs_topic_posterior(self, self.sent_state.get_shape()[-1], self.sent_state, self.mask_doc, self.config) # batch_l x max_doc_l x n_topic

            with tf.name_scope('latents_topic'):
                # get topic sentence posterior distribution for each document
                self.probs_topic_posterior = tf.reduce_sum(
                    self.probs_sent_topic_posterior, 1)  # batch_l x n_topic

                self.means_sent_topic_posterior = tf.multiply(tf.expand_dims(self.probs_sent_topic_posterior, -1), \
                        tf.expand_dims(self.means_sent_posterior, -2)) # batch_l x max_doc_l x n_topic x dim_latent
                self.means_topic_posterior_ = tf.reduce_sum(self.means_sent_topic_posterior, 1) / \
                        tf.expand_dims(self.probs_topic_posterior, -1) # batch_l x n_topic x dim_latent
                self.means_topic_posterior = tf_clip_means(
                    self.means_topic_posterior_, self.probs_topic_posterior)

                diffs_sent_topic_posterior = tf.expand_dims(self.means_sent_posterior, 2) - \
                        tf.expand_dims(self.means_topic_posterior, 1) # batch_l x max_doc_l x n_topic x dim_latent
                self.covs_sent_topic_posterior = tf.multiply(tf.expand_dims(tf.expand_dims(self.probs_sent_topic_posterior, -1), -1), \
                        tf.matrix_diag(tf.expand_dims(tf.exp(self.logvars_sent_posterior), 2)) + tf.matmul(tf.expand_dims(diffs_sent_topic_posterior, -1), \
                        tf.expand_dims(diffs_sent_topic_posterior, -2))) # batch_l x max_doc_l x n_topic x dim_latent x dim_latent
                self.covs_topic_posterior_ = tf.reduce_sum(self.covs_sent_topic_posterior, 1) / \
                        tf.expand_dims(tf.expand_dims(self.probs_topic_posterior, -1), -1) # batch_l x n_topic x dim_latent x dim_latent
                self.covs_topic_posterior = tf_clip_covs(
                    self.covs_topic_posterior_, self.probs_topic_posterior)

                self.latents_topic_posterior = sample_latents_fullcov(self.means_topic_posterior, self.covs_topic_posterior, \
                                                                      seed=self.config.seed, sample=self.sample)

                self.mean_root_prior = tf.zeros(
                    [self.batch_l, self.config.dim_latent], dtype=tf.float32)
                self.cov_root_prior = tf.eye(
                    self.config.dim_latent,
                    batch_shape=[self.batch_l],
                    dtype=tf.float32) * self.config.cov_root
                self.means_topic_prior = get_params_topic_prior(
                    self, self.means_topic_posterior,
                    self.mean_root_prior)  # batch_l x n_topic x dim_latent
                self.covs_topic_prior = get_params_topic_prior(
                    self, self.covs_topic_posterior, self.cov_root_prior
                )  # batch_l x n_topic x dim_latent x dim_latent

        # ------------------------------Decoder----------------------------------
        with tf.variable_scope('dec'):
            with tf.variable_scope('word', reuse=False):
                self.dec_embeddings = get_embeddings(self)

            # decode for training sent
            with tf.variable_scope(
                    'sent',
                    initializer=tf.contrib.layers.xavier_initializer(),
                    dtype=tf.float32,
                    reuse=False):
                self.dec_cell = tf.contrib.rnn.GRUCell(self.config.dim_hidden)
                self.dec_cell = tf.contrib.rnn.DropoutWrapper(
                    self.dec_cell,
                    output_keep_prob=self.t_variables['dec_keep_prob'])

                self.latent_hidden_layer = tf.layers.Dense(
                    units=self.config.dim_hidden,
                    activation=tf.nn.relu,
                    name='latent_hidden_linear')
                self.dec_sent_initial_state = self.latent_hidden_layer(
                    self.latents_sent_posterior
                )  # batch_l x max_doc_l x dim_hidden
                self.output_layer = tf.layers.Dense(self.config.n_vocab,
                                                    use_bias=False,
                                                    dtype=tf.float32,
                                                    name='output')

                if self.config.pointer:

                    def pointer_layer(cell_outputs, cell_state, memory_idxs):
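                        # Pointer-generator mixture: the final token distribution is
                        # prob_gen * (generation softmax) + (1 - prob_gen) * (attention weights
                        # scattered onto the source token ids), so the decoder can copy words,
                        # including OOVs, directly from the input sentence.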
                        output_probs = tf.nn.softmax(
                            cell_outputs, -1)  # batch_l*n_tile x n_vocab

                        if self.config.oov:
                            if output_probs.shape.rank == 3:
                                output_probs = tf.concat([
                                    output_probs,
                                    tf.zeros([
                                        tf.shape(output_probs)[0],
                                        tf.shape(output_probs)[1], self.n_oov
                                    ],
                                             dtype=output_probs.dtype)
                                ], -1)
                            elif output_probs.shape.rank == 2:
                                output_probs = tf.concat([
                                    output_probs,
                                    tf.zeros([
                                        tf.shape(output_probs)[0], self.n_oov
                                    ],
                                             dtype=output_probs.dtype)
                                ], -1)

                        prob_gen = cell_state.prob_gen  # batch_l*n_tile x 1
                        alignments = cell_state.alignments  # batch_l*n_tile x memory_l
                        if output_probs.shape.rank == 3:
                            alignments = tf.reshape(alignments, [
                                tf.shape(alignments)[0] *
                                tf.shape(alignments)[1],
                                tf.shape(alignments)[-1]
                            ])

                        memory_indices = tf.concat([
                            tf.expand_dims(
                                tf.tile(
                                    tf.expand_dims(
                                        tf.range(tf.shape(memory_idxs)[0]),
                                        -1), [1, tf.shape(memory_idxs)[1]]),
                                -1),
                            tf.expand_dims(memory_idxs, -1)
                        ], -1)
                        attn_probs = tf.tensor_scatter_nd_add(
                            tf.zeros([
                                tf.shape(alignments)[0],
                                self.config.n_vocab + self.n_oov
                            ],
                                     dtype=tf.float32),
                            indices=memory_indices,
                            updates=alignments)  # batch_l*n_tile x n_vocab

                        if output_probs.shape.rank == 3:
                            attn_probs = tf.reshape(attn_probs, [
                                tf.shape(output_probs)[0],
                                tf.shape(output_probs)[1],
                                tf.shape(output_probs)[-1]
                            ])

                        pointer_probs = prob_gen * output_probs + (
                            1. - prob_gen) * attn_probs
                        pointer_logits = tf_log(pointer_probs)
                        return pointer_logits

                    self.prob_gen_layer = tf.layers.Dense(
                        1,
                        use_bias=True,
                        activation=tf.nn.sigmoid,
                        dtype=tf.float32,
                        name='pointer')
                    self.pointer_layer = pointer_layer
                else:
                    self.prob_gen_layer = None
                    self.pointer_layer = None

                if self.config.attention:
                    self.dec_sent_cell = wrap_attention(self,
                                                        self.dec_cell,
                                                        self.sent_outputs,
                                                        n_tiled=self.max_doc_l)
                else:
                    self.dec_sent_cell = self.dec_cell

                # teacher forcing
                self.dec_input_idxs = self.t_variables[
                    'dec_input_idxs']  # batch_l x max_doc_l x max_dec_sent_l
                self.dec_inputs = tf.nn.embedding_lookup(
                    self.dec_embeddings, self.dec_input_idxs
                )  # batch_l x max_doc_l x max_dec_sent_l x dim_emb

                # output_sent_l == dec_sent_l
                self.dec_sent_outputs, self.dec_sent_final_state, self.output_sent_l_flat = decode_output_logits_flat(
                    self,
                    dec_cell=self.dec_sent_cell,
                    dec_initial_state=self.dec_sent_initial_state,
                    dec_inputs=self.dec_inputs,
                    dec_sent_l=self.dec_sent_l,
                    latents_input=self.latents_sent_posterior
                )  # batch_l*max_doc_l x max_output_sent_l x n_vocab

                self.output_logits_flat = self.dec_sent_outputs.rnn_output
                self.output_sent_l = tf.reshape(self.output_sent_l_flat,
                                                [self.batch_l, self.max_doc_l])
                self.max_output_sent_l = tf.reduce_max(self.output_sent_l)
                self.output_logits = tf.reshape(
                    self.output_logits_flat,
                    [self.batch_l, self.max_doc_l, self.max_output_sent_l,
                     tf.shape(self.output_logits_flat)[-1]], name='output_logits')
                if self.config.disc_gumbel:
                    self.output_input_idxs = sample_gumbels(
                        self.output_logits, self.softmax_temperature,
                        self.config.seed, self.sample
                    )  # batch_l x max_doc_l x max_output_sent_l  x n_vocab
                else:
                    self.output_input_idxs = self.output_logits

                # re-encode original sentence outputs
                self.output_inputs = tf.tensordot(
                    self.output_input_idxs,
                    self.enc_embeddings,
                    axes=[[-1],
                          [0]])  # batch_l x max_doc_l x max_sent_l x dim_emb
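                # When disc_gumbel is set, output_input_idxs holds relaxed one-hot
                # (Gumbel-softmax) samples over the vocabulary, so the tensordot above
                # acts as a differentiable "soft" embedding lookup (the expected
                # embedding under the sampled distribution) rather than a hard lookup.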
                self.output_input_sent_l = self.output_sent_l - 1  # to remove EOS
                self.mask_output_doc = self.mask_doc

            # decode for training topic probs
            with tf.variable_scope(
                    'sent',
                    initializer=tf.contrib.layers.xavier_initializer(),
                    dtype=tf.float32,
                    reuse=True):
                if self.config.attention:
                    self.dec_topic_cell = wrap_attention(
                        self,
                        self.dec_cell,
                        self.sent_outputs,
                        n_tiled=self.config.n_topic)
                else:
                    self.dec_topic_cell = self.dec_cell

                if not self.config.disc_mean:
                    self.dec_topic_initial_state = self.latent_hidden_layer(
                        self.latents_topic_posterior)
                    self.dec_topic_outputs, self.dec_topic_final_state, self.summary_sent_l_flat = decode_output_sample_flat(
                        self,
                        dec_cell=self.dec_topic_cell,
                        dec_initial_state=self.dec_topic_initial_state,
                        softmax_temperature=self.softmax_temperature,
                        sample=self.sample,
                        latents_input=self.latents_topic_posterior
                    )  # batch_l*n_topic x max_summary_sent_l x n_vocab
                else:
                    self.dec_topic_initial_state = self.latent_hidden_layer(
                        self.means_topic_posterior)
                    self.dec_topic_outputs, self.dec_topic_final_state, self.summary_sent_l_flat = decode_output_sample_flat(
                        self,
                        dec_cell=self.dec_topic_cell,
                        dec_initial_state=self.dec_topic_initial_state,
                        softmax_temperature=self.softmax_temperature,
                        sample=self.sample,
                        latents_input=self.means_topic_posterior
                    )  # batch_l*n_topic x max_summary_sent_l x n_vocab

                self.summary_sent_l = tf.reshape(
                    self.summary_sent_l_flat,
                    [self.batch_l, self.config.n_topic])
                self.max_summary_sent_l = tf.reduce_max(self.summary_sent_l)
                if self.config.disc_gumbel:
                    summary_input_idxs_flat = self.dec_topic_outputs.sample_id
                else:
                    summary_input_idxs_flat = self.dec_topic_outputs.rnn_output
                self.summary_input_idxs = tf.reshape(
                    summary_input_idxs_flat,
                    [self.batch_l, self.config.n_topic, self.max_summary_sent_l,
                     tf.shape(summary_input_idxs_flat)[-1]], name='summary_input_idxs')

                # re-encode topic sentence outputs
                self.summary_inputs = tf.tensordot(
                    self.summary_input_idxs,
                    self.enc_embeddings,
                    axes=[[-1], [
                        0
                    ]])  # batch_l x n_topic x max_summary_sent_l x dim_emb
                self.summary_input_sent_l = self.summary_sent_l - 1  # to remove EOS
                self.mask_summary_doc = tf.ones(
                    [self.batch_l, self.config.n_topic], dtype=tf.float32)

            # beam decode for inference of original sentences
            with tf.variable_scope(
                    'sent',
                    initializer=tf.contrib.layers.xavier_initializer(),
                    dtype=tf.float32,
                    reuse=True):
                if self.config.nucleus < 1:
                    self.beam_output_idxs = decode_sample_output_token_idxs(
                        self,
                        dec_cell=self.dec_sent_cell,
                        dec_initial_state=self.dec_sent_initial_state,
                        latents_input=self.means_sent_posterior,
                        name='beam_output_idxs')
                else:
                    if self.config.attention:
                        self.beam_dec_sent_cell = wrap_attention(
                            self,
                            self.dec_cell,
                            self.sent_outputs,
                            n_tiled=self.max_doc_l,
                            beam_width=self.config.beam_width)
                    else:
                        self.beam_dec_sent_cell = self.dec_cell

                    # infer original sentences
                    self.beam_output_idxs = decode_beam_output_token_idxs(
                        self,
                        beam_dec_cell=self.beam_dec_sent_cell,
                        dec_initial_state=self.dec_sent_initial_state,
                        latents_input=self.means_sent_posterior,
                        name='beam_output_idxs')

            # beam decode for inference of topic sentences
            with tf.variable_scope(
                    'sent',
                    initializer=tf.contrib.layers.xavier_initializer(),
                    dtype=tf.float32,
                    reuse=True):
                if self.config.nucleus < 1:
                    self.beam_summary_idxs = decode_sample_output_token_idxs(
                        self,
                        dec_cell=self.dec_topic_cell,
                        dec_initial_state=self.dec_topic_initial_state,
                        latents_input=self.means_topic_posterior,
                        name='beam_summary_idxs')
                else:
                    if self.config.attention:
                        self.beam_dec_topic_cell = wrap_attention(
                            self,
                            self.dec_cell,
                            self.sent_outputs,
                            n_tiled=self.config.n_topic,
                            beam_width=self.config.beam_width)
                    else:
                        self.beam_dec_topic_cell = self.dec_cell

                    # infer topic sentences
                    self.beam_summary_idxs = decode_beam_output_token_idxs(
                        self,
                        beam_dec_cell=self.beam_dec_topic_cell,
                        dec_initial_state=self.dec_topic_initial_state,
                        latents_input=self.means_topic_posterior,
                        name='beam_summary_idxs')

        # ------------------------------Discriminator------------------------------
        with tf.variable_scope('enc'):
            with tf.variable_scope('sent', reuse=True):
                _, self.output_state = encode_inputs(
                    self,
                    enc_inputs=self.output_inputs,
                    sent_l=self.output_input_sent_l
                )  # batch_l x max_doc_l x dim_hidden*2

            with tf.variable_scope('sent', reuse=True):
                _, self.summary_state = encode_inputs(
                    self,
                    enc_inputs=self.summary_inputs,
                    sent_l=self.summary_input_sent_l
                )  # batch_l x n_topic x dim_hidden*2

        with tf.variable_scope('disc'):
            with tf.variable_scope('prob_topic', reuse=True):
                self.probs_output_topic_posterior, _, _ = \
                    encode_nhdp_probs_topic_posterior(
                        self, self.output_state.get_shape()[-1],
                        self.output_state, self.mask_output_doc, self.config)

                self.logits_output_topic_posterior = tf_log(
                    self.probs_output_topic_posterior
                )  # batch_l x max_doc_l x n_topic

            with tf.variable_scope('prob_topic', reuse=True):
                if self.config.latent:
                    self.latents_probs_summary_topic_posterior, self.means_probs_summary_topic_posterior, self.logvars_probs_summary_topic_posterior = \
                        encode_latents_gauss(self.summary_state,
                                             dim_latent=self.config.dim_latent,
                                             sample=self.sample,
                                             config=self.config,
                                             name='probs_topic_posterior')  # batch_l x n_topic x dim_latent
                    self.probs_summary_topic_posterior, _, _ = \
                        encode_nhdp_probs_topic_posterior(
                            self, self.latents_probs_summary_topic_posterior.get_shape()[-1],
                            self.latents_probs_summary_topic_posterior,
                            self.mask_summary_doc, self.config)  # batch_l x n_topic x n_topic
                else:
                    self.probs_summary_topic_posterior, _, _ = \
                        encode_nhdp_probs_topic_posterior(
                            self, self.summary_state.get_shape()[-1],
                            self.summary_state, self.mask_summary_doc,
                            self.config)

                self.logits_summary_topic_posterior = tf_log(
                    tf.matrix_diag_part(self.probs_summary_topic_posterior))
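                # probs_summary_topic_posterior is batch_l x n_topic x n_topic; its
                # diagonal entry k is the probability that the summary generated for
                # topic k is itself classified as topic k, and its log feeds the
                # topic-discriminator loss below.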

        # ------------------------------Optimizer and Loss------------------------------
        with tf.name_scope('opt'):
            partition_doc = tf.cast(self.mask_doc, dtype=tf.int32)
            self.n_sents = tf.cast(tf.reduce_sum(self.doc_l), dtype=tf.float32)
            self.n_tokens = tf.reduce_sum(self.dec_sent_l)
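            # partition_doc is the integer 0/1 sentence mask; each
            # tf.dynamic_partition(..., partition_doc, num_partitions=2)[1] below keeps
            # only the entries of real (non-padded) sentences, so the subsequent
            # reduce_mean averages over actual sentences rather than padding.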

            # ------------------------------Reconstruction Loss of Language Model------------------------------
            # target and mask
            self.dec_target_idxs = self.t_variables[
                'dec_target_idxs']  # batch_l x max_doc_l x max_dec_sent_l
            self.dec_sent_l = self.t_variables[
                'dec_sent_l']  # batch_l x max_doc_l
            self.max_dec_sent_l = tf.reduce_max(
                self.dec_sent_l)  # = max_sent_l + 1
            self.dec_mask_sent = tf.sequence_mask(self.dec_sent_l,
                                                  maxlen=self.max_dec_sent_l,
                                                  dtype=tf.float32)
            self.dec_target_idxs_flat = tf.reshape(
                self.dec_target_idxs,
                [self.batch_l * self.max_doc_l, self.max_dec_sent_l])
            self.dec_mask_sent_flat = tf.reshape(
                self.dec_mask_sent,
                [self.batch_l * self.max_doc_l, self.max_dec_sent_l])

            # nll for each token (summed over sentence)
            self.recon_max_sent_l = tf.minimum(
                self.max_dec_sent_l,
                self.max_output_sent_l) if self.config.sample else None
            losses_recon_flat = tf.reduce_sum(
                tf.contrib.seq2seq.sequence_loss(
                    self.output_logits_flat[:, :self.recon_max_sent_l, :],
                    self.dec_target_idxs_flat[:, :self.recon_max_sent_l],
                    self.dec_mask_sent_flat[:, :self.recon_max_sent_l],
                    average_across_timesteps=False,
                    average_across_batch=False), -1)  # batch_l*max_doc_l
            self.losses_recon = tf.reshape(losses_recon_flat,
                                           [self.batch_l, self.max_doc_l])
            self.loss_recon = tf.reduce_mean(
                tf.dynamic_partition(
                    self.losses_recon, partition_doc,
                    num_partitions=2)[1])  # average over doc x batch

            # ------------------------------KL divergence Loss of Topic Probability Distribution------------------------------
            if self.config.latent:
                self.losses_kl_prob = compute_kl_losses(
                    self.means_probs_sent_topic_posterior,
                    self.logvars_probs_sent_topic_posterior)
            else:
                if self.config.prior:
                    self.probs_sent_topic_prior = tf.tile(
                        tf.expand_dims(
                            tf.expand_dims(
                                tf.constant(self.config.probs_topic_prior,
                                            dtype=tf.float32), 0), 0),
                        [self.batch_l, self.max_doc_l, 1])
                else:
                    self.probs_sent_topic_prior = tf.ones_like(
                        self.probs_sent_topic_posterior, dtype=tf.float32
                    ) / self.config.n_topic  # batch_l x max_doc_l x n_topic, uniform distribution over topics
                self.losses_kl_prob = tf.reduce_sum(
                    tf.multiply(
                        self.probs_sent_topic_posterior,
                        tf_log(self.probs_sent_topic_posterior) -
                        tf_log(self.probs_sent_topic_prior)), -1)
            self.loss_kl_prob = tf.reduce_mean(
                tf.dynamic_partition(
                    self.losses_kl_prob, partition_doc,
                    num_partitions=2)[1])  # average over doc x batch
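            # In the non-latent branch above, losses_kl_prob is the exact categorical KL
            # KL(q || p) = sum_k q_k * (log q_k - log p_k) between the per-sentence topic
            # posterior q and the (uniform or configured) topic prior p.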

            # ------------------------------KL divergence Loss of Sentence Latents Distribution------------------------------
            self.losses_kl_sent_gauss = compute_kl_losses_sent_gauss(
                self
            )  # batch_l x max_doc_l x n_topic, sum over latent dimension
            self.losses_kl_sent_gmm = tf.reduce_sum(
                tf.multiply(self.probs_sent_topic_posterior,
                            self.losses_kl_sent_gauss),
                -1)  # batch_l x max_doc_l, sum over topics
            self.loss_kl_sent_gmm = tf.reduce_mean(
                tf.dynamic_partition(
                    self.losses_kl_sent_gmm, partition_doc,
                    num_partitions=2)[1])  # average over doc x batch
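            # losses_kl_sent_gmm is the GMM-style KL: the per-topic Gaussian KL terms from
            # compute_kl_losses_sent_gauss are weighted by the sentence-level topic
            # posterior and summed over topics, i.e. an expectation over the topic assignment.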

            # for monitor
            self.losses_kl_topic_gauss = compute_kl_losses_topic_gauss(
                self)  # batch_l x 1 x n_topic, sum over latent dimension
            self.losses_kl_topic_gmm = tf.reduce_sum(
                tf.multiply(self.probs_sent_topic_posterior,
                            self.losses_kl_topic_gauss),
                -1)  # batch_l x max_doc_l, sum over topics
            self.loss_kl_topic_gmm = tf.reduce_mean(
                tf.dynamic_partition(self.losses_kl_topic_gmm,
                                     partition_doc,
                                     num_partitions=2)[1])

            # ------------------------------Discriminator Loss of Original sentences----------------------------
            if self.config.disc_sent:
                #                 self.losses_disc_sent = -tf.reduce_sum(tf.multiply(self.probs_sent_topic_posterior, self.logits_output_topic_posterior), -1)
                self.losses_disc_sent = tf.nn.softmax_cross_entropy_with_logits_v2(
                    labels=self.probs_sent_topic_posterior,
                    logits=self.logits_output_topic_posterior,
                    dim=-1)  # batch_l x max_doc_l, sum over topics
                self.loss_disc_sent = tf.reduce_mean(
                    tf.dynamic_partition(self.losses_disc_sent,
                                         partition_doc,
                                         num_partitions=2)[1])
            else:
                self.loss_disc_sent = tf.constant(0, dtype=tf.float32)

            # ------------------------------Discriminator Loss of Topic sentences------------------------------
            if self.config.disc_topic:
                self.losses_disc_topic = -tf_clip_vals(
                    self.logits_summary_topic_posterior,
                    self.probs_topic_posterior)  # batch_l x n_topic
                if self.config.disc_weight:
                    self.loss_disc_topic = tf.reduce_mean(
                        self.losses_disc_topic)  # average over topic x batch
                else:
                    self.loss_disc_topic = tf.reduce_sum(
                        self.losses_disc_topic
                    ) / self.n_sents  # average over doc x batch
            else:
                self.loss_disc_topic = tf.constant(0, dtype=tf.float32)

            # ------------------------------Coverage Loss------------------------------
            if self.config.regularizer:
                self.losses_coverage_sent = compute_losses_coverage(
                    self,
                    self.dec_sent_final_state,
                    self.dec_mask_sent,
                    n_tiled=self.max_doc_l)  # batch_l x max_doc_l
                self.loss_coverage_sent = tf.reduce_mean(
                    tf.dynamic_partition(
                        self.losses_coverage_sent,
                        partition_doc,
                        num_partitions=2)[1])  # average over doc x batch

                self.dec_mask_summary = tf.sequence_mask(
                    self.summary_sent_l,
                    maxlen=self.max_summary_sent_l,
                    dtype=tf.float32)
                self.losses_coverage_topic = compute_losses_coverage(
                    self,
                    self.dec_topic_final_state,
                    self.dec_mask_summary,
                    n_tiled=self.config.n_topic)  # batch_l x n_topic
                if self.config.disc_weight:
                    self.loss_coverage_topic = tf.reduce_mean(
                        self.losses_coverage_topic
                    )  # average over topic x batch
                else:
                    self.loss_coverage_topic = tf.reduce_sum(
                        self.losses_coverage_topic
                    ) / self.n_sents  # average over doc x batch
                self.loss_coverage = self.loss_coverage_sent + self.loss_coverage_topic
            else:
                self.loss_coverage = tf.constant(0, dtype=tf.float32)

            # ------------------------------Optimizer------------------------------
            if self.config.anneal == 'linear':
                self.tau = tf.cast(tf.divide(
                    self.global_step, tf.constant(self.config.linear_steps)),
                                   dtype=tf.float32)
                self.beta = tf.minimum(self.config.beta_last,
                                       self.config.beta_init + self.tau)
            else:
                self.beta = tf.constant(1.)
            self.beta_disc = self.beta if self.config.beta_disc else tf.constant(
                1.)
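            # Linear KL annealing: tau = global_step / linear_steps, so beta ramps from
            # beta_init up to beta_last and is then held there; without annealing beta is
            # fixed at 1. beta_disc follows the same schedule only if config.beta_disc is set.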

            if self.config.lr_step:
                self.lr_step = tf.minimum(
                    tf.cast(self.global_step, tf.float32)**(-1 / 2),
                    tf.cast(self.global_step, tf.float32) *
                    self.config.warmup**(-3 / 2))
            else:
                self.lr_step = 1.
            self.lr = self.config.lr * self.lr_step
            self.lr_disc = self.config.lr_disc * self.lr_step
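            # With config.lr_step, lr_step = min(step^(-1/2), step * warmup^(-3/2)) is a
            # Transformer-style warmup-then-inverse-sqrt-decay factor applied to both the
            # generator and discriminator learning rates.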

            def get_opt(loss, var_list, lr, global_step=None):
                if self.config.opt == 'adam':
                    Optimizer = tf.train.AdamOptimizer
                elif self.config.opt == 'adagrad':
                    Optimizer = tf.train.AdagradOptimizer

                optimizer = Optimizer(lr)
                grad_vars = optimizer.compute_gradients(loss=loss,
                                                        var_list=var_list)
                clipped_grad_vars = [
                    (tf.clip_by_value(grad, -self.config.grad_clip,
                                      self.config.grad_clip), var)
                    for grad, var in grad_vars if grad is not None
                ]
                opt = optimizer.apply_gradients(clipped_grad_vars,
                                                global_step=global_step)
                return opt, grad_vars, clipped_grad_vars
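            # get_opt builds the configured Adam/Adagrad optimizer, clips each gradient
            # element-wise to [-grad_clip, grad_clip], skips variables whose gradient is
            # None (not reachable from the given loss), and returns the apply_gradients
            # op together with the raw and clipped gradient pairs.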

            # ------------------------------Loss Setting------------------------------
            if self.config.turn:
                self.loss = self.loss_recon + \
                                    self.beta * tf.maximum(self.loss_kl_sent_gmm, self.config.capacity_gmm) + \
                                    self.config.lam_reg * self.loss_coverage

                self.opt, self.grad_vars, self.clipped_grad_vars = \
                    get_opt(self.loss, var_list=list(tf.trainable_variables('enc') + tf.trainable_variables('dec')), lr=self.lr, global_step=self.global_step)
                self.opt_infer, self.grad_vars_infer, self.clipped_grad_vars_infer = \
                    get_opt(self.loss, var_list=list(tf.trainable_variables('enc')), lr=self.lr, global_step=self.global_step)
                self.opt_gen, self.grad_vars_gen, self.clipped_grad_vars_gen = \
                    get_opt(self.loss, var_list=list(tf.trainable_variables('dec')), lr=self.lr, global_step=self.global_step)

                self.loss_disc = self.beta_disc * self.config.lam_disc * self.loss_disc_sent + \
                                            self.beta_disc * self.config.lam_disc * self.loss_disc_topic + \
                                            self.beta * tf.maximum(self.loss_kl_prob, self.config.capacity_prob)

                self.opt_disc, self.grad_vars_disc, self.clipped_grad_vars_disc = \
                    get_opt(self.loss_disc, var_list=list(tf.trainable_variables('enc')+tf.trainable_variables('disc')), lr=self.lr_disc)
                self.opt_disc_infer, self.grad_vars_disc_infer, self.clipped_grad_vars_disc_infer = \
                    get_opt(self.loss_disc, var_list=list(tf.trainable_variables('enc')+tf.trainable_variables('disc')), lr=self.lr_disc)
                self.opt_disc_gen = tf.constant(0, dtype=tf.float32)

                self.opt_enc = tf.constant(0, dtype=tf.float32)

            elif self.config.control:
                self.loss = self.loss_recon + \
                                    self.beta * tf.maximum(self.loss_kl_sent_gmm, self.config.capacity_gmm) + \
                                    self.beta * tf.maximum(self.loss_kl_prob, self.config.capacity_prob) + \
                                    self.config.lam_reg * self.loss_coverage

                self.loss_disc = self.beta_disc * self.config.lam_disc * self.loss_disc_sent + \
                                            self.beta_disc * self.config.lam_disc * self.loss_disc_topic

                self.opt, self.grad_vars, self.clipped_grad_vars = \
                    get_opt(self.loss+self.loss_disc, var_list=list(tf.trainable_variables('dec')), lr=self.lr, global_step=self.global_step)

                self.opt_enc, self.grad_vars_enc, self.clipped_grad_vars_enc = \
                    get_opt(self.loss, var_list=list(tf.trainable_variables('enc')), lr=self.lr)

                self.opt_disc, self.grad_vars_disc, self.clipped_grad_vars_disc = \
                    get_opt(self.loss_disc, var_list=list(tf.trainable_variables('disc')), lr=self.lr_disc)

            else:
                self.loss = self.loss_recon + \
                                    self.beta * tf.maximum(self.loss_kl_sent_gmm, self.config.capacity_gmm) + \
                                    self.beta_disc * self.config.lam_disc * self.loss_disc_sent + \
                                    self.beta_disc * self.config.lam_disc * self.loss_disc_topic + \
                                    self.beta * tf.maximum(self.loss_kl_prob, self.config.capacity_prob) + \
                                    self.config.lam_reg * self.loss_coverage

                self.opt, self.grad_vars, self.clipped_grad_vars = \
                    get_opt(self.loss, var_list=tf.trainable_variables(), lr=self.lr, global_step=self.global_step)
                self.opt_infer, self.grad_vars_infer, self.clipped_grad_vars_infer = \
                    get_opt(self.loss, var_list=list(tf.trainable_variables('enc') + tf.trainable_variables('disc')), lr=self.lr, global_step=self.global_step)
                self.opt_gen, self.grad_vars_gen, self.clipped_grad_vars_gen = \
                    get_opt(self.loss, var_list=list(tf.trainable_variables('dec')), lr=self.lr, global_step=self.global_step)

                self.loss_disc = tf.constant(0, dtype=tf.float32)
                self.opt_enc = tf.constant(0, dtype=tf.float32)
                self.opt_disc = tf.constant(0, dtype=tf.float32)
                self.opt_disc_infer = tf.constant(0, dtype=tf.float32)
                self.opt_disc_gen = tf.constant(0, dtype=tf.float32)

            # for monitoring logdetcov
            self.logdetcovs_topic_posterior = tf_log(
                tf.linalg.det(self.covs_topic_posterior))
            self.mask_depth = tf.tile(
                tf.expand_dims(tf.constant(
                    [self.config.tree_depth[topic_idx] - 1
                     for topic_idx in self.config.topic_idxs], dtype=tf.int32),
                    0), [self.batch_l, 1])
            self.depth_logdetcovs_topic_posterior = tf.dynamic_partition(
                self.logdetcovs_topic_posterior, self.mask_depth,
                num_partitions=max(self.config.tree_depth.values())
            )  # list over depths, each entry: n_topic_depth*batch_l

            # ------------------------------Evaluation------------------------------
            self.loss_list_train = [
                self.loss, self.loss_disc, self.loss_recon, self.loss_kl_prob,
                self.loss_kl_sent_gmm, self.loss_disc_sent,
                self.loss_disc_topic, self.loss_coverage
            ]
            self.loss_list_eval = [
                self.loss, self.loss_disc, self.loss_recon, self.loss_kl_prob,
                self.loss_kl_sent_gmm, self.loss_disc_sent,
                self.loss_disc_topic, self.loss_kl_topic_gmm
            ]
            self.loss_sum = (self.loss_recon + self.loss_kl_prob +
                             self.loss_kl_sent_gmm) * self.n_sents
Example #5
0
                    def pointer_layer(cell_outputs, cell_state, memory_idxs):
                        output_probs = tf.nn.softmax(
                            cell_outputs, -1)  # batch_l*n_tile x n_vocab

                        if self.config.oov:
                            if output_probs.shape.rank == 3:
                                output_probs = tf.concat([
                                    output_probs,
                                    tf.zeros([
                                        tf.shape(output_probs)[0],
                                        tf.shape(output_probs)[1], self.n_oov
                                    ],
                                             dtype=output_probs.dtype)
                                ], -1)
                            elif output_probs.shape.rank == 2:
                                output_probs = tf.concat([
                                    output_probs,
                                    tf.zeros([
                                        tf.shape(output_probs)[0], self.n_oov
                                    ],
                                             dtype=output_probs.dtype)
                                ], -1)

                        prob_gen = cell_state.prob_gen  # batch_l*n_tile x 1
                        alignments = cell_state.alignments  # batch_l*n_tile x memory_l
                        if output_probs.shape.rank == 3:
                            alignments = tf.reshape(alignments, [
                                tf.shape(alignments)[0] *
                                tf.shape(alignments)[1],
                                tf.shape(alignments)[-1]
                            ])

                        memory_indices = tf.concat([
                            tf.expand_dims(
                                tf.tile(
                                    tf.expand_dims(
                                        tf.range(tf.shape(memory_idxs)[0]),
                                        -1), [1, tf.shape(memory_idxs)[1]]),
                                -1),
                            tf.expand_dims(memory_idxs, -1)
                        ], -1)
                        attn_probs = tf.tensor_scatter_nd_add(
                            tf.zeros([
                                tf.shape(alignments)[0],
                                self.config.n_vocab + self.n_oov
                            ],
                                     dtype=tf.float32),
                            indices=memory_indices,
                            updates=alignments)  # batch_l*n_tile x n_vocab

                        if output_probs.shape.rank == 3:
                            attn_probs = tf.reshape(attn_probs, [
                                tf.shape(output_probs)[0],
                                tf.shape(output_probs)[1],
                                tf.shape(output_probs)[-1]
                            ])

                        pointer_probs = prob_gen * output_probs + (
                            1. - prob_gen) * attn_probs
                        pointer_logits = tf_log(pointer_probs)
                        return pointer_logits
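
A minimal NumPy sketch (toy sizes and values chosen only for illustration, not taken from the example above) of the pointer-generator mixture that pointer_layer computes: attention weights are scatter-added onto the extended vocab + OOV ids of the source tokens, then mixed with the decoder's vocabulary distribution via prob_gen.

import numpy as np

n_vocab, n_oov = 6, 2                                      # extended vocabulary size = 8
output_probs = np.array([[0.1, 0.4, 0.1, 0.2, 0.1, 0.1]])  # decoder vocab distribution, batch_l = 1
output_probs = np.concatenate(
    [output_probs, np.zeros((1, n_oov))], axis=-1)         # pad OOV slots with zero probability
alignments = np.array([[0.5, 0.3, 0.2]])                   # attention over 3 source tokens
memory_idxs = np.array([[1, 6, 3]])                        # source token ids; id 6 is an OOV
prob_gen = 0.7                                             # probability of generating vs copying

# scatter-add analogue of tf.tensor_scatter_nd_add: accumulate attention mass per token id
attn_probs = np.zeros((1, n_vocab + n_oov))
for b in range(memory_idxs.shape[0]):
    for t, idx in enumerate(memory_idxs[b]):
        attn_probs[b, idx] += alignments[b, t]

pointer_probs = prob_gen * output_probs + (1. - prob_gen) * attn_probs
print(pointer_probs)  # OOV id 6 receives mass only through the copy (attention) term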