Example #1
def decoder(mel_targets, encoder_output, scope="decoder", training=True, reuse=None):

    with tf.variable_scope(scope, reuse=reuse):
      
      decoder_cell = TacotronDecoderWrapper(unidirectional_LSTM(training, layers=hp.dec_LSTM_layers, size=hp.dec_LSTM_size), training)

      attention_decoder = AttentionWrapper(
        decoder_cell,
        LocationBasedAttention(hp.attention_size, encoder_output),
        #BahdanauAttention(hp.attention_size, encoder_output),
        alignment_history=True,
        output_attention=False)

      decoder_state = attention_decoder.zero_state(batch_size=hp.batch_size, dtype=tf.float32)
      projection = tf.tile([[0.0]], [hp.batch_size, hp.n_mels])
      final_projection = tf.zeros([hp.batch_size, hp.T_y//hp.r, hp.n_mels], tf.float32)
      if hp.include_dones:
        LSTM_att = tf.zeros([hp.batch_size, hp.T_y//hp.r, hp.dec_LSTM_size*2], tf.float32)
      else:
        LSTM_att = 0
      step = 0

      def att_condition(step, projection, final_projection, decoder_state, mel_targets, LSTM_att):
        return step < hp.T_y//hp.r

      def att_body(step, projection, final_projection, decoder_state, mel_targets, LSTM_att):
        if training:
          # Teacher forcing: feed a zero frame at the first step, then the previous target frame.
          # `step` is a tensor inside tf.while_loop, so the branch needs a tf.cond rather than a
          # Python `if step == 0`.
          prev_frame = tf.cond(
            tf.equal(step, 0),
            lambda: tf.tile([[0.0]], [hp.batch_size, hp.n_mels]),
            lambda: mel_targets[:, step-1, :])
          projection, decoder_state, _, LSTM_next = attention_decoder.call(prev_frame, decoder_state)
        else:
          # At inference time, feed back the previous prediction.
          projection, decoder_state, _, LSTM_next = attention_decoder.call(projection, decoder_state)
        # Append the new frame and drop the leading placeholder so the buffer keeps a fixed length.
        fprojection = tf.expand_dims(projection, axis=1)
        final_projection = tf.concat([final_projection, fprojection], axis=1)[:, 1:, :]
        if hp.include_dones:
          fLSTM_next = tf.expand_dims(LSTM_next, axis=1)
          LSTM_att = tf.concat([LSTM_att, fLSTM_next], axis=1)[:, 1:, :]
        return ((step+1), projection, final_projection, decoder_state, mel_targets, LSTM_att)
        
      res_loop = tf.while_loop(att_condition, att_body,
        loop_vars=[step, projection, final_projection, decoder_state, mel_targets, LSTM_att],
        parallel_iterations=hp.parallel_iterations, swap_memory=False)

      final_projection = res_loop[2]
      final_decoder_state = res_loop[3]
      concat_LSTM_att = res_loop[5]
      step = res_loop[0]

      if hp.print_shapes: print(final_projection)

      with tf.variable_scope("postnet"):
        tensor = final_projection
        for i in range(hp.dec_postnet_layers):
          tensor = conv1d(tensor,
            filters=hp.dec_postnet_filters,
            kernel_size=hp.dec_postnet_size,
            activation=tf.nn.tanh if i<4 else None,
            training=training,
            dropout_rate=hp.dropout_rate,
            scope="decoder_conv_{}".format(i)) # (N, Tx, c)
        #tensor = tf.layers.dense(tensor,hp.n_mels)
        tensor = tf.contrib.layers.fully_connected(tensor, hp.n_mels, activation_fn=None, biases_initializer=tf.zeros_initializer())
      if hp.print_shapes: print(tensor)

      mel_logits = final_projection + tensor
      if hp.print_shapes: print(mel_logits)

      if hp.include_dones:
        with tf.variable_scope("done_output"):
            done_output = fc_block(concat_LSTM_att, 2, training=training)
            done_output = tf.nn.sigmoid(done_output)
        if hp.print_shapes: print(done_output)
      else:
        done_output = None
        concat_LSTM_att = None

    return mel_logits, final_projection, done_output, final_decoder_state, concat_LSTM_att,step
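
The decoder above returns both the pre-postnet projection and the postnet-refined spectrogram. Below is a minimal sketch of how it might be wired into a training graph; the placeholder shapes, `hp.enc_units`, and the plain L1 loss are illustrative assumptions, not taken from the original project.

# Hypothetical wiring of the decoder above (placeholder shapes, hp.enc_units and the
# L1 loss are assumptions for illustration only).
mel_targets = tf.placeholder(tf.float32, [hp.batch_size, hp.T_y // hp.r, hp.n_mels])
encoder_output = tf.placeholder(tf.float32, [hp.batch_size, None, hp.enc_units])

mel_logits, mel_before_postnet, done_output, final_state, lstm_att, n_steps = decoder(
    mel_targets, encoder_output, scope="decoder", training=True)

# Supervise both the raw projection and the postnet-refined output, Tacotron-style.
loss = tf.reduce_mean(tf.abs(mel_targets - mel_before_postnet)) + \
       tf.reduce_mean(tf.abs(mel_targets - mel_logits))
train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)
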
Example #2
    ap.add_argument("--report-result",
                    type=str,
                    default=None,
                    help="File where to save the results.")
    ap.add_argument("-s",
                    "--sentences",
                    nargs='*',
                    type=int,
                    default=None,
                    help="Only use the specified sentences; 0-based")

    args = ap.parse_args()

    dependency_tree = Dependency(args.conll, args.tokens)
    bert_attns = AttentionWrapper(args.attentions,
                                  dependency_tree.wordpieces2tokens,
                                  args.sentences)

    if args.evaluate_only:
        if not args.json:
            raise ValueError(
                "JSON with head ensembles required in evaluate only mode!")
        with open(args.json, 'r') as inj:
            head_ensembles = json.load(inj)
        head_ensembles = {
            rl: HeadEnsemble.from_dict(**he_dict)
            for rl, he_dict in head_ensembles.items()
        }

    else:
        head_ensembles = dict()
Example #3
    ax.set_yticks(np.arange(len(sentence_tokens)))
    ax.set_xticks(np.arange(len(sentence_tokens)))
    ax.set_xticklabels(sentence_tokens, rotation=90)
    ax.set_yticklabels(sentence_tokens)
    ax.set_ylim(top=-0.5, bottom=len(sentence_tokens) - 0.5)
    
    plt.savefig(out_file, dpi=300, format='pdf')


if __name__ == '__main__':
    ap = argparse.ArgumentParser()
    ap.add_argument("attentions", type=str, help="NPZ file with attentions")
    ap.add_argument("tokens", type=str, help="Labels (tokens) separated by spaces")
    ap.add_argument("conll", type=str, help="Conll file for head selection.")

    ap.add_argument("-layer_idcs", nargs="*", type=int, default=[5, 3], help = "layer indices to plot")
    ap.add_argument("-head_idcs", nargs="*", type=int, default=[4, 9], help="head indices to plot")
    ap.add_argument("-s", "--sentences", nargs='*', type=int, default=list(range(10)), help="Only use the specified sentences; 0-based")
    
    ap.add_argument("-vis-dir", type=str, default="../results", help="Directory where to save head visualizations")
   
    args = ap.parse_args()
    
    dependency_tree = Dependency(args.conll, args.tokens)
    bert_attns = AttentionWrapper(args.attentions, dependency_tree.wordpieces2tokens, args.sentences)
    
    for sent_idx, attn_mats in bert_attns:
        for l, h in zip(args.layer_idcs, args.head_idcs):
            out_file = os.path.join(args.vis_dir, f"L-{l}_H-{h}_sent-{sent_idx}.pdf")
            plot_head(attn_mats, dependency_tree.tokens[sent_idx], l, h, out_file)
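
Only the axis-formatting tail of the plotting helper is visible earlier in this example. A minimal sketch of what the full `plot_head` helper might look like follows, under the assumption that `attn_mats` is indexed as [layer][head] and holds token-by-token attention matrices; the figure size and colormap are arbitrary choices, not taken from the original source.

# Hypothetical reconstruction of plot_head; the [layer][head] indexing of attn_mats,
# the figure size and the colormap are assumptions.
import matplotlib.pyplot as plt
import numpy as np


def plot_head(attn_mats, sentence_tokens, layer_idx, head_idx, out_file):
    matrix = attn_mats[layer_idx][head_idx]  # tokens x tokens attention weights
    fig, ax = plt.subplots(figsize=(6, 6))
    ax.imshow(matrix, cmap='Blues')
    ax.set_yticks(np.arange(len(sentence_tokens)))
    ax.set_xticks(np.arange(len(sentence_tokens)))
    ax.set_xticklabels(sentence_tokens, rotation=90)
    ax.set_yticklabels(sentence_tokens)
    ax.set_ylim(top=-0.5, bottom=len(sentence_tokens) - 0.5)
    plt.savefig(out_file, dpi=300, format='pdf')
    plt.close(fig)
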
Example #4
    def build(self):
        self.global_step = tf.Variable(0, name='global_step', trainable=False)
        self.softmax_temperature = tf.maximum(
            self.config.max_temperature - tf.cast(
                tf.divide(self.global_step, tf.constant(self.config.linear_steps)),
                dtype=tf.float32),
            self.config.min_temperature)

        with tf.name_scope('t_variables'):
            self.sample = self.t_variables['sample']

            self.batch_l = self.t_variables['batch_l']
            self.doc_l = self.t_variables['doc_l']
            self.sent_l = self.t_variables['sent_l']
            self.dec_sent_l = self.t_variables[
                'dec_sent_l']  # batch_l x max_doc_l

            self.max_doc_l = tf.reduce_max(self.doc_l)
            self.max_sent_l = tf.reduce_max(self.sent_l)
            self.max_dec_sent_l = tf.reduce_max(
                self.dec_sent_l)  # = max_sent_l + 1

            self.mask_doc = tf.sequence_mask(self.doc_l, dtype=tf.float32)
            self.mask_sent = tf.sequence_mask(self.sent_l, dtype=tf.float32)

            mask_bow = np.zeros(self.config.n_vocab)
            mask_bow[self.config.bow_idxs] = 1.
            self.mask_bow = tf.constant(mask_bow, dtype=tf.float32)

            self.enc_keep_prob = self.t_variables['enc_keep_prob']

        # ------------------------------Encoder------------------------------
        with tf.variable_scope('emb'):
            with tf.variable_scope('word', reuse=False):
                pad_embedding = tf.zeros([1, self.config.dim_emb],
                                         dtype=tf.float32)
                nonpad_embeddings = tf.get_variable(
                    'emb', [self.config.n_vocab - 1, self.config.dim_emb],
                    dtype=tf.float32,
                    initializer=tf.contrib.layers.xavier_initializer())
                self.embeddings = tf.concat([pad_embedding, nonpad_embeddings],
                                            0)  # n_vocab x dim_emb
                self.bow_embeddings = tf.nn.embedding_lookup(
                    self.embeddings, self.config.bow_idxs)  # dim_bow x dim_emb

                # get sentence embeddings
                self.enc_input_idxs = tf.one_hot(
                    self.t_variables['enc_input_idxs'],
                    depth=self.config.n_vocab
                )  # batch_l x max_doc_l x max_sent_l x n_vocab
                self.enc_inputs = tf.tensordot(
                    self.enc_input_idxs, self.embeddings,
                    axes=[[-1],
                          [0]])  # batch_l x max_doc_l x max_sent_l x dim_emb

            with tf.variable_scope('sent', reuse=False):
                self.sent_outputs, self.sent_state = \
                    encode_inputs(self, enc_inputs=self.enc_inputs, sent_l=self.sent_l) # batch_l x max_doc_l x dim_hidden*2

        with tf.variable_scope('enc'):
            # get sentence latents
            with tf.variable_scope('latents_sent', reuse=False):
                self.w_topic_posterior = tf.get_variable(
                    'topic_posterior/kernel', [
                        self.config.n_topic, self.sent_state.shape[-1],
                        self.config.dim_hidden
                    ],
                    dtype=tf.float32)
                self.b_topic_posterior = tf.get_variable(
                    'topic_posterior/bias',
                    [1, self.config.n_topic, self.config.dim_hidden],
                    dtype=tf.float32)

                self.topic_state = tf.reduce_sum(
                    self.sent_state * tf.expand_dims(self.mask_doc, -1),
                    -2) / tf.reduce_sum(self.mask_doc, -1, keepdims=True)
                self.hidden_topic_posterior = tf.tensordot(
                    self.topic_state, self.w_topic_posterior, axes=[[1], [1]]
                ) + self.b_topic_posterior  # batch_l x n_topic x dim_hidden

        # ------------------------------Discriminator------------------------------
        with tf.variable_scope('disc'):
            with tf.variable_scope('prob_topic', reuse=False):
                # encode by TSNTM
                self.probs_sent_topic_posterior, _, _ = \
                    encode_gsm_probs_topic_posterior(self, self.hidden_topic_posterior.get_shape()[-1], self.hidden_topic_posterior, self.mask_doc, self.config) # batch_l x max_doc_l x n_topic

            with tf.name_scope('latents_topic'):
                # get topic sentence posterior distribution for each document
                self.probs_topic_posterior = tf.reduce_sum(
                    self.probs_sent_topic_posterior, 1)  # batch_l x n_topic

                self.means_sent_topic_posterior = tf.multiply(tf.expand_dims(self.probs_sent_topic_posterior, -1), \
                        tf.expand_dims(self.means_sent_posterior, -2)) # batch_l x max_doc_l x n_topic x dim_latent
                self.means_topic_posterior_ = tf.reduce_sum(self.means_sent_topic_posterior, 1) / \
                        tf.expand_dims(self.probs_topic_posterior, -1) # batch_l x n_topic x dim_latent
                self.means_topic_posterior = tf_clip_means(
                    self.means_topic_posterior_, self.probs_topic_posterior)

                diffs_sent_topic_posterior = tf.expand_dims(self.means_sent_posterior, 2) - \
                        tf.expand_dims(self.means_topic_posterior, 1) # batch_l x max_doc_l x n_topic x dim_latent
                self.covs_sent_topic_posterior = tf.multiply(tf.expand_dims(tf.expand_dims(self.probs_sent_topic_posterior, -1), -1), \
                        tf.matrix_diag(tf.expand_dims(tf.exp(self.logvars_sent_posterior), 2)) + tf.matmul(tf.expand_dims(diffs_sent_topic_posterior, -1), \
                        tf.expand_dims(diffs_sent_topic_posterior, -2))) # batch_l x max_doc_l x n_topic x dim_latent x dim_latent
                self.covs_topic_posterior_ = tf.reduce_sum(self.covs_sent_topic_posterior, 1) / \
                        tf.expand_dims(tf.expand_dims(self.probs_topic_posterior, -1), -1) # batch_l x n_topic x dim_latent x dim_latent
                self.covs_topic_posterior = tf_clip_covs(
                    self.covs_topic_posterior_, self.probs_topic_posterior)

                self.latents_topic_posterior = sample_latents_fullcov(self.means_topic_posterior, self.covs_topic_posterior, \
                                                                      seed=self.config.seed, sample=self.sample)

                self.means_topic_prior = tf.zeros(
                    [
                        self.batch_l, self.config.n_topic,
                        self.config.dim_latent
                    ],
                    dtype=tf.float32)  # batch_l x n_topic x dim_latent
                self.covs_topic_prior = tf.eye(
                    self.config.dim_latent,
                    batch_shape=[self.batch_l, self.config.n_topic],
                    dtype=tf.float32) * self.config.cov_root

        # ------------------------------Decoder----------------------------------
        with tf.variable_scope('dec'):
            # decode for training sent
            with tf.variable_scope(
                    'sent',
                    initializer=tf.contrib.layers.xavier_initializer(),
                    dtype=tf.float32,
                    reuse=False):
                self.dec_cell = tf.contrib.rnn.GRUCell(self.config.dim_hidden)
                self.dec_cell = tf.contrib.rnn.DropoutWrapper(
                    self.dec_cell,
                    output_keep_prob=self.t_variables['dec_keep_prob'])
                self.dec_sent_cell = self.dec_cell
                self.latent_hidden_layer = tf.layers.Dense(
                    units=self.config.dim_hidden,
                    activation=tf.nn.relu,
                    name='latent_hidden_linear')
                self.dec_sent_initial_state = self.latent_hidden_layer(
                    self.latents_sent_posterior
                )  # batch_l x max_doc_l x dim_hidden
                self.output_layer = tf.layers.Dense(self.config.n_vocab,
                                                    use_bias=False,
                                                    name='out')

                if self.config.attention:
                    self.sent_outputs_flat = tf.reshape(
                        self.sent_outputs, [
                            self.batch_l * self.max_doc_l, self.max_sent_l,
                            self.config.dim_hidden * 2
                        ])
                    self.att_sent_l_flat = tf.reshape(
                        tf.maximum(self.sent_l, 1),
                        [self.batch_l * self.max_doc_l])
                    self.att_sent_mechanism = tf.contrib.seq2seq.LuongAttention(num_units=self.config.dim_hidden,
                                                                                memory=self.sent_outputs_flat, \
                                                                                memory_sequence_length=self.att_sent_l_flat)
                    self.att_cell = tf.contrib.seq2seq.AttentionWrapper(
                        self.dec_cell,
                        attention_mechanism=self.att_sent_mechanism,
                        attention_layer_size=self.config.dim_hidden)
                    self.dec_sent_cell = self.att_cell

                # teacher forcing
                self.dec_input_idxs = self.t_variables[
                    'dec_input_idxs']  # batch_l x max_doc_l x max_dec_sent_l
                self.dec_inputs = tf.nn.embedding_lookup(
                    self.embeddings, self.dec_input_idxs
                )  # batch_l x max_doc_l x max_dec_sent_l x dim_emb

                # output_sent_l == dec_sent_l
                self.output_logits_flat, self.output_sent_l_flat = decode_output_logits_flat(
                    self,
                    dec_cell=self.dec_sent_cell,
                    dec_initial_state=self.dec_sent_initial_state,
                    dec_inputs=self.dec_inputs,
                    dec_sent_l=self.dec_sent_l,
                    latents_input=self.latents_sent_posterior
                )  # batch_l*max_doc_l x max_output_sent_l x n_vocab

                self.output_sent_l = tf.reshape(self.output_sent_l_flat,
                                                [self.batch_l, self.max_doc_l])
                self.max_output_sent_l = tf.reduce_max(self.output_sent_l)
                self.output_logits = tf.reshape(self.output_logits_flat, \
                                    [self.batch_l, self.max_doc_l, self.max_output_sent_l, self.config.n_vocab], name='output_logits')
                if self.config.disc_gumbel:
                    self.output_input_idxs = sample_gumbels(
                        self.output_logits, self.softmax_temperature,
                        self.config.seed, self.sample
                    )  # batch_l x max_doc_l x max_output_sent_l  x n_vocab
                else:
                    self.output_input_idxs = self.output_logits

            # decode for training topic probs
            with tf.variable_scope(
                    'sent',
                    initializer=tf.contrib.layers.xavier_initializer(),
                    dtype=tf.float32,
                    reuse=True):
                self.dec_topic_cell = self.dec_cell
                if self.config.attention:
                    self.topic_outputs_flat = tf.contrib.seq2seq.tile_batch(tf.reshape(self.sent_outputs, \
                                            [self.batch_l, self.max_doc_l*self.max_sent_l, self.sent_outputs.get_shape()[-1]]), \
                                            multiplier=self.config.n_topic) # batch_l*n_topic x max_doc_l*max_sent_l x dim_hidden*2
                    self.score_mask = tf.contrib.seq2seq.tile_batch(tf.reshape(tf.sequence_mask(self.sent_l), \
                                            [self.batch_l, self.max_doc_l*self.max_sent_l]), multiplier=self.config.n_topic) # batch_l*n_topic x max_doc_l*max_sent_l
                    self.hier_score = tf.reshape(tf.transpose(self.probs_sent_topic_posterior, [0, 2, 1]), \
                                            [self.batch_l*self.config.n_topic, self.max_doc_l]) # batch_l*n_topic x max_doc_l

                    self.att_topic_mechanism = HierarchicalAttention(
                        num_units=self.config.dim_hidden,
                        memory=self.topic_outputs_flat,
                        score_mask=self.score_mask,
                        hier_score=self.hier_score)
                    self.att_topic_cell = AttentionWrapper(
                        self.dec_cell,
                        attention_mechanism=self.att_topic_mechanism,
                        attention_layer_size=self.config.dim_hidden)
                    self.dec_topic_cell = self.att_topic_cell

                if not self.config.disc_mean:
                    self.dec_topic_initial_state = self.latent_hidden_layer(
                        self.latents_topic_posterior)
                    dec_topic_outputs, self.summary_sent_l_flat = decode_output_sample_flat(
                        self,
                        dec_cell=self.dec_topic_cell,
                        dec_initial_state=self.dec_topic_initial_state,
                        softmax_temperature=self.softmax_temperature,
                        sample=self.sample,
                        latents_input=self.latents_topic_posterior
                    )  # batch_l*max_doc_l x max_summary_sent_l x n_vocab
                else:
                    self.dec_topic_initial_state = self.latent_hidden_layer(
                        self.means_topic_posterior)
                    dec_topic_outputs, self.summary_sent_l_flat = decode_output_sample_flat(
                        self,
                        dec_cell=self.dec_topic_cell,
                        dec_initial_state=self.dec_topic_initial_state,
                        softmax_temperature=self.softmax_temperature,
                        sample=self.sample,
                        latents_input=self.means_topic_posterior
                    )  # batch_l*max_doc_l x max_summary_sent_l x n_vocab

                self.summary_sent_l = tf.reshape(
                    self.summary_sent_l_flat,
                    [self.batch_l, self.config.n_topic])
                self.max_summary_sent_l = tf.reduce_max(self.summary_sent_l)
                if self.config.disc_gumbel:
                    summary_input_idxs_flat = dec_topic_outputs.sample_id
                else:
                    summary_input_idxs_flat = dec_topic_outputs.rnn_output
                self.summary_input_idxs = tf.reshape(summary_input_idxs_flat, \
                                                     [self.batch_l, self.config.n_topic, self.max_summary_sent_l, self.config.n_vocab], name='summary_input_idxs')

                # re-encode topic sentence outputs
                self.summary_inputs = tf.tensordot(
                    self.summary_input_idxs, self.embeddings, axes=[[-1], [
                        0
                    ]])  # batch_l x n_topic x max_summary_sent_l x dim_emb
                self.summary_input_sent_l = self.summary_sent_l - 1  # to remove EOS
                self.mask_summary_sent = tf.sequence_mask(self.summary_input_sent_l, \
                                                          maxlen=self.max_summary_sent_l, dtype=tf.float32) # batch_l x n_topic x max_summary_sent_l
                self.mask_summary_doc = tf.ones(
                    [self.batch_l, self.config.n_topic], dtype=tf.float32)

            # beam decode for inference of original sentences
            with tf.variable_scope(
                    'sent',
                    initializer=tf.contrib.layers.xavier_initializer(),
                    dtype=tf.float32,
                    reuse=True):
                self.beam_dec_sent_cell = self.dec_cell
                if self.config.attention:
                    self.beam_sent_outputs_flat = tf.contrib.seq2seq.tile_batch(
                        self.sent_outputs_flat,
                        multiplier=self.config.beam_width)
                    self.beam_att_sent_l_flat = tf.contrib.seq2seq.tile_batch(
                        self.att_sent_l_flat,
                        multiplier=self.config.beam_width)
                    self.beam_att_sent_mechanism = tf.contrib.seq2seq.LuongAttention(
                        num_units=self.config.dim_hidden,
                        memory=self.beam_sent_outputs_flat,
                        memory_sequence_length=self.beam_att_sent_l_flat)
                    self.beam_dec_sent_cell = tf.contrib.seq2seq.AttentionWrapper(
                        self.beam_dec_sent_cell,
                        attention_mechanism=self.beam_att_sent_mechanism,
                        attention_layer_size=self.config.dim_hidden)

                # infer original sentences
                self.beam_output_idxs, _, _ = decode_beam_output_token_idxs(
                    self,
                    beam_dec_cell=self.beam_dec_sent_cell,
                    dec_initial_state=self.dec_sent_initial_state,
                    latents_input=self.means_sent_posterior,
                    name='beam_output_idxs')

            # beam decode for inference of topic sentences
            with tf.variable_scope(
                    'sent',
                    initializer=tf.contrib.layers.xavier_initializer(),
                    dtype=tf.float32,
                    reuse=True):
                self.beam_dec_topic_cell = self.dec_cell
                if self.config.attention:
                    self.beam_topic_outputs_flat = tf.contrib.seq2seq.tile_batch(
                        self.topic_outputs_flat,
                        multiplier=self.config.beam_width)
                    self.beam_score_mask = tf.contrib.seq2seq.tile_batch(
                        self.score_mask, multiplier=self.config.beam_width)
                    self.beam_hier_score = tf.contrib.seq2seq.tile_batch(
                        self.hier_score, multiplier=self.config.beam_width)
                    self.beam_att_topic_mechanism = HierarchicalAttention(
                        num_units=self.config.dim_hidden,
                        memory=self.beam_topic_outputs_flat,
                        score_mask=self.beam_score_mask,
                        hier_score=self.beam_hier_score)
                    self.beam_dec_topic_cell = AttentionWrapper(
                        self.beam_dec_topic_cell,
                        attention_mechanism=self.beam_att_topic_mechanism,
                        attention_layer_size=self.config.dim_hidden)

                # infer topic sentences
                self.beam_summary_idxs, _, _ = decode_beam_output_token_idxs(
                    self,
                    beam_dec_cell=self.beam_dec_topic_cell,
                    dec_initial_state=self.dec_topic_initial_state,
                    latents_input=self.latents_topic_posterior,
                    name='beam_summary_idxs')

                self.beam_mask_summary_sent = tf.logical_not(tf.equal(self.beam_summary_idxs, \
                                                                      self.config.EOS_IDX)) # batch_l x n_topic x max_summary_sent_l
                self.beam_summary_input_sent_l = tf.reduce_sum(
                    tf.cast(self.beam_mask_summary_sent, tf.int32),
                    -1)  # batch_l x n_topic
                beam_summary_soft_idxs = tf.one_hot(tf.where(self.beam_mask_summary_sent, \
                                                                            self.beam_summary_idxs, tf.zeros_like(self.beam_summary_idxs)), depth=self.config.n_vocab)
                self.beam_summary_inputs = tf.tensordot(beam_summary_soft_idxs, \
                                                        self.embeddings, [[-1], [0]]) # batch_l x n_topic x max_beam_summary_sent_l x dim_emb

        # ------------------------------Discriminator------------------------------
        # encode by MLP
        if self.config.enc == 'mlp':
            with tf.variable_scope('disc'):
                with tf.variable_scope('prob_topic', reuse=True):
                    self.summary_state = encode_states(self, enc_inputs=self.summary_inputs, mask_sent=self.mask_summary_sent, \
                                                                   enc_keep_prob=self.enc_keep_prob, config=self.config) # batch_l x n_topic x dim_hidden
        elif self.config.enc == 'bow':
            with tf.variable_scope('disc'):
                with tf.variable_scope('prob_topic', reuse=True):
                    self.bow_summary_input_idxs = tf.multiply(
                        self.summary_input_idxs, self.mask_bow)
                    self.bow_summary_inputs = tf.tensordot(
                        self.bow_summary_input_idxs,
                        self.embeddings,
                        axes=[[-1], [0]
                              ])  # batch_l x max_doc_l x max_sent_l x dim_emb
                    self.mask_summary_bow = tf.reduce_sum(
                        self.bow_summary_input_idxs, -1)
                    self.summary_state = encode_states(self, enc_inputs=self.bow_summary_inputs, mask_sent=self.mask_summary_bow, \
                                                                   enc_keep_prob=self.enc_keep_prob, config=self.config) # batch_l x max_doc_l x dim_hidden
        elif self.config.enc == 'rnn':
            with tf.variable_scope('emb'):
                with tf.variable_scope('sent', reuse=True):
                    _, self.summary_state = encode_inputs(
                        self,
                        enc_inputs=self.summary_inputs,
                        sent_l=self.summary_input_sent_l
                    )  # batch_l x max_doc_l x dim_hidden*2
                    _, self.beam_summary_state = encode_inputs(
                        self,
                        enc_inputs=self.beam_summary_inputs,
                        sent_l=self.beam_summary_input_sent_l
                    )  # batch_l x max_doc_l x dim_hidden*2

        with tf.variable_scope('disc'):
            with tf.variable_scope('prob_topic', reuse=True):
                self.probs_summary_topic_posterior, _, _ = \
                        encode_gsm_probs_topic_posterior(self, self.summary_state.get_shape()[-1], self.summary_state, self.mask_summary_doc, self.config)
                self.logits_summary_topic_posterior_ = tf_log(
                    tf.matrix_diag_part(self.probs_summary_topic_posterior)
                )  # batch_l x n_topic
                self.logits_summary_topic_posterior = tf_clip_vals(
                    self.logits_summary_topic_posterior_,
                    self.probs_topic_posterior)

        # ------------------------------Optimizer and Loss------------------------------
        with tf.name_scope('opt'):
            partition_doc = tf.cast(self.mask_doc, dtype=tf.int32)
            self.n_sents = tf.cast(tf.reduce_sum(self.doc_l), dtype=tf.float32)
            self.n_tokens = tf.reduce_sum(self.dec_sent_l)

            # ------------------------------Reconstruction Loss of Language Model------------------------------
            # target and mask
            self.dec_target_idxs = self.t_variables[
                'dec_target_idxs']  # batch_l x max_doc_l x max_dec_sent_l
            self.dec_sent_l = self.t_variables[
                'dec_sent_l']  # batch_l x max_doc_l
            self.max_dec_sent_l = tf.reduce_max(
                self.dec_sent_l)  # = max_sent_l + 1
            self.dec_mask_sent = tf.sequence_mask(self.dec_sent_l,
                                                  maxlen=self.max_dec_sent_l,
                                                  dtype=tf.float32)
            self.dec_target_idxs_flat = tf.reshape(
                self.dec_target_idxs,
                [self.batch_l * self.max_doc_l, self.max_dec_sent_l])
            self.dec_mask_sent_flat = tf.reshape(
                self.dec_mask_sent,
                [self.batch_l * self.max_doc_l, self.max_dec_sent_l])

            # nll for each token (summed over sentence)
            self.recon_max_sent_l = tf.minimum(
                self.max_dec_sent_l,
                self.max_output_sent_l) if self.config.sample else None
            losses_recon_flat = tf.reduce_sum(
                tf.contrib.seq2seq.sequence_loss(
                    self.output_logits_flat[:, :self.recon_max_sent_l, :],
                    self.dec_target_idxs_flat[:, :self.recon_max_sent_l],
                    self.dec_mask_sent_flat[:, :self.recon_max_sent_l],
                    average_across_timesteps=False,
                    average_across_batch=False), -1)  # batch_l*max_doc_l
            self.losses_recon = tf.reshape(losses_recon_flat,
                                           [self.batch_l, self.max_doc_l])
            self.loss_recon = tf.reduce_mean(
                tf.dynamic_partition(
                    self.losses_recon, partition_doc,
                    num_partitions=2)[1])  # average over doc x batch

            # ------------------------------KL divergence Loss of Topic Probability Distribution------------------------------
            if self.config.topic_model:
                self.probs_sent_topic_prior = tf.expand_dims(
                    self.probs_doc_topic_posterior, 1)  # batch_l x 1 x n_topic
            else:
                self.probs_sent_topic_prior = tf.ones_like(self.probs_sent_topic_posterior, dtype=tf.float32) / \
                                                        self.config.n_topic # batch_l x max_doc_l x n_topic, uniform distribution over topics
            self.losses_kl_prob = tf.reduce_sum(tf.multiply(self.probs_sent_topic_posterior, \
                                                            (tf_log(self.probs_sent_topic_posterior)-tf_log(self.probs_sent_topic_prior))), -1)
            self.loss_kl_prob = tf.reduce_mean(
                tf.dynamic_partition(
                    self.losses_kl_prob, partition_doc,
                    num_partitions=2)[1])  # average over doc x batch

            # ------------------------------KL divergence Loss of Sentence Latents Distribution------------------------------
            self.losses_kl_sent_gauss = compute_kl_losses_sent_gauss(
                self
            )  # batch_l x max_doc_l x n_topic, sum over latent dimension
            self.losses_kl_sent_gmm = tf.reduce_sum(
                tf.multiply(self.probs_sent_topic_posterior,
                            self.losses_kl_sent_gauss),
                -1)  # batch_l x max_doc_l, sum over topics
            self.loss_kl_sent_gmm = tf.reduce_mean(
                tf.dynamic_partition(
                    self.losses_kl_sent_gmm, partition_doc,
                    num_partitions=2)[1])  # average over doc x batch

            # ------------------------------KL divergence Loss of Topic Latents Distribution------------------------------
            if self.config.reverse_kl:
                self.losses_kl_topic_pairs_gauss = compute_kl_losses_topic_paris_gauss(
                    self)
                self.losses_kl_topic_gauss_reverse = tf.reduce_sum(self.losses_kl_topic_pairs_gauss * self.config.mask_tree[None, None, :, :], -1) / \
                                        np.maximum(np.sum(self.config.mask_tree[None, None, :, :], -1), 1) # batch_l x 1 x n_topic, mean over other child topics
                self.losses_kl_topic_gmm_reverse = tf.reduce_sum(
                    tf.multiply(self.probs_sent_topic_posterior,
                                self.losses_kl_topic_gauss_reverse),
                    -1)  # batch_l x max_doc_l, sum over topics
                self.loss_kl_topic_gmm_reverse = tf.reduce_mean(
                    tf.dynamic_partition(self.losses_kl_topic_gmm_reverse,
                                         partition_doc,
                                         num_partitions=2)[1])
            else:
                self.loss_kl_topic_gmm_reverse = tf.constant(0.,
                                                             dtype=tf.float32)

            # for monitor
            self.losses_kl_topic_gauss = compute_kl_losses_topic_gauss(
                self)  # batch_l x 1 x n_topic, sum over latent dimension
            self.losses_kl_topic_gmm = tf.reduce_sum(
                tf.multiply(self.probs_sent_topic_posterior,
                            self.losses_kl_topic_gauss),
                -1)  # batch_l x max_doc_l, sum over topics
            self.loss_kl_topic_gmm = tf.reduce_mean(
                tf.dynamic_partition(self.losses_kl_topic_gmm,
                                     partition_doc,
                                     num_partitions=2)[1])

            # ------------------------------KL divergence Loss of Root State Distribution------------------------------
            if self.config.prior_root:
                self.losses_kl_root = compute_kl_losses(
                    self.means_state_root_posterior,
                    self.logvars_state_root_posterior)  # batch_l x max_doc_l
                self.loss_kl_root = tf.reduce_sum(
                    self.losses_kl_root) / tf.cast(
                        tf.reduce_sum(self.doc_l),
                        dtype=tf.float32)  # average over doc x batch
            else:
                self.loss_kl_root = tf.constant(0, dtype=tf.float32)

            # ------------------------------Discriminator Loss------------------------------
            if self.config.disc_topic:
                self.losses_disc_topic = -tf.reduce_sum(
                    self.logits_summary_topic_posterior,
                    -1)  # batch_l, sum over topic
                self.loss_disc_topic = tf.reduce_sum(
                    self.losses_disc_topic
                ) / self.n_sents  # average over doc x batch
            else:
                self.loss_disc_topic = tf.constant(0, dtype=tf.float32)

            # ------------------------------Loss of Topic Model------------------------------
            if self.config.topic_model:
                # recon
                self.topic_losses_recon = -tf.reduce_sum(
                    tf.multiply(self.t_variables['doc_bows'], self.logits_bow),
                    -1)  # n_batch, sum over n_bow
                self.topic_loss_recon = tf.reduce_mean(
                    self.topic_losses_recon)  # average over doc x batch

                # kl_bow
                self.means_topic_bow_prior = tf.squeeze(get_params_topic_prior(self, tf.expand_dims(self.means_topic_bow_posterior, 0), \
                                                                    tf.zeros([1, self.config.dim_latent], dtype=tf.float32)), 0) # n_topic x dim_latent
                self.logvars_topic_bow_prior = tf.squeeze(get_params_topic_prior(self, tf.expand_dims(self.logvars_topic_bow_posterior, 0), \
                                                                                tf.zeros([1, self.config.dim_latent], dtype=tf.float32)), 0) # n_topic x dim_latent
                self.topic_losses_kl_bow = compute_kl_losses(self.means_topic_bow_posterior, self.logvars_topic_bow_posterior, \
                                                                            means_prior=self.means_topic_bow_prior, logvars_prior=self.logvars_topic_bow_prior) # n_topic
                self.topic_loss_kl_bow = tf.reduce_mean(
                    self.topic_losses_kl_bow)  # average over doc x batch

                # kl_prob
                self.topic_losses_kl_prob = compute_kl_losses(
                    self.means_probs_doc_topic_posterior,
                    self.logvars_probs_doc_topic_posterior)  # batch_l
                self.topic_loss_kl_prob = tf.reduce_mean(
                    self.topic_losses_kl_prob)  # average over doc x batch
            else:
                self.topic_loss_recon = tf.constant(0, dtype=tf.float32)
                self.topic_loss_kl_bow = tf.constant(0, dtype=tf.float32)
                self.topic_loss_kl_prob = tf.constant(0, dtype=tf.float32)

            # ------------------------------Topic Regularization Loss------------------------------
            if self.config.reg != '':
                if self.config.reg == 'mean':
                    self.topic_dots = self.get_topic_dots(
                        self.means_topic_posterior
                    )  # batch_l x n_topic-1 x n_topic-1
                elif self.config.reg == 'bow':
                    self.topic_dots = self.get_topic_dots(
                        tf.expand_dims(
                            self.topic_bow,
                            0))  # batch_l(=1) x n_topic-1 x n_topic-1

                self.losses_reg = tf.reduce_sum(tf.square(self.topic_dots - tf.eye(len(self.config.all_child_idxs))) * self.config.mask_tree_reg, [1, 2])\
                                        / tf.reduce_sum(self.config.mask_tree_reg) # batch_l
                self.loss_reg = tf.reduce_mean(
                    self.losses_reg)  # average over batch
            else:
                self.loss_reg = tf.constant(0, dtype=tf.float32)

            # ------------------------------Optimizer------------------------------
            if self.config.anneal == 'linear':
                self.tau = tf.cast(tf.divide(
                    self.global_step, tf.constant(self.config.linear_steps)),
                                   dtype=tf.float32)
                self.beta = tf.minimum(1., self.config.beta_init + self.tau)
            elif self.config.anneal == 'cycle':
                self.tau = tf.cast(tf.divide(
                    tf.mod(self.global_step,
                           tf.constant(self.config.cycle_steps)),
                    tf.constant(self.config.cycle_steps)),
                                   dtype=tf.float32)
                self.beta = tf.minimum(
                    1., self.config.beta_init + self.tau /
                    (1. - self.config.r_cycle))
            else:
                self.beta = tf.constant(1.)

            self.beta_disc = self.beta if self.config.beta_disc else tf.constant(1.)

            def get_opt(loss, var_list, lr, global_step=None):
                if self.config.opt == 'adam':
                    Optimizer = tf.train.AdamOptimizer
                elif self.config.opt == 'adagrad':
                    Optimizer = tf.train.AdagradOptimizer

                optimizer = Optimizer(lr)
                grad_vars = optimizer.compute_gradients(loss=loss,
                                                        var_list=var_list)
                clipped_grad_vars = [
                    (tf.clip_by_value(grad, -self.config.grad_clip,
                                      self.config.grad_clip), var)
                    for grad, var in grad_vars if grad is not None
                ]
                opt = optimizer.apply_gradients(clipped_grad_vars,
                                                global_step=global_step)
                return opt, grad_vars, clipped_grad_vars

            # ------------------------------Loss Setting------------------------------
            if self.config.turn:
                self.loss = self.loss_recon + \
                             self.beta * tf.maximum(tf.maximum(self.loss_kl_sent_gmm, self.config.capacity_gmm) \
                                                            - self.loss_kl_topic_gmm_reverse, self.config.margin_gmm) + \
                             self.beta * self.loss_kl_root + \
                             self.topic_loss_recon + \
                             self.beta * self.topic_loss_kl_bow + \
                             self.beta * self.topic_loss_kl_prob + \
                             self.config.lam_reg * self.loss_reg

                self.opt, self.grad_vars, self.clipped_grad_vars = \
                    get_opt(self.loss, var_list=list(tf.trainable_variables('emb') + tf.trainable_variables('enc') + tf.trainable_variables('dec')), \
                                lr=self.config.lr, global_step=self.global_step)

                self.loss_disc = self.beta_disc * self.config.lam_disc * self.loss_disc_topic + \
                                    self.beta * tf.maximum(self.loss_kl_prob, self.config.capacity_prob)

                self.opt_disc, self.grad_vars_disc, self.clipped_grad_vars_disc = \
                    get_opt(self.loss_disc, var_list=list(tf.trainable_variables('emb') + tf.trainable_variables('disc')), lr=self.config.lr_disc)

            else:
                self.loss = self.loss_recon + \
                             self.beta * tf.maximum(tf.maximum(self.loss_kl_sent_gmm, self.config.capacity_gmm) \
                                                            - self.loss_kl_topic_gmm_reverse, self.config.margin_gmm) + \
                             self.beta * self.loss_kl_root + \
                             self.topic_loss_recon + \
                             self.beta * self.topic_loss_kl_bow + \
                             self.beta * self.topic_loss_kl_prob + \
                             self.beta_disc * self.config.lam_disc * self.loss_disc_topic + \
                             self.beta * tf.maximum(self.loss_kl_prob, self.config.capacity_prob) + \
                             self.config.lam_reg * self.loss_reg
                self.loss_disc = tf.constant(0, dtype=tf.float32)

                self.opt, self.grad_vars, self.clipped_grad_vars = \
                    get_opt(self.loss, var_list=tf.trainable_variables(), lr=self.config.lr, global_step=self.global_step)
                self.opt_disc = tf.constant(0, dtype=tf.float32)

            # ------------------------------Evaluation------------------------------
            self.loss_list_train = [self.loss, self.loss_disc, self.loss_recon, self.loss_kl_prob, self.loss_kl_sent_gmm, self.loss_kl_topic_gmm_reverse, \
                self.loss_kl_root, self.loss_disc_topic, self.topic_loss_recon, self.topic_loss_kl_bow, self.topic_loss_kl_prob, self.loss_reg, tf.constant(0)]
            self.loss_list_eval = [self.loss, self.loss_disc, self.loss_recon, self.loss_kl_prob, self.loss_kl_sent_gmm, self.loss_kl_topic_gmm_reverse, \
                self.loss_kl_root, self.loss_disc_topic, self.topic_loss_recon, self.topic_loss_kl_bow, self.topic_loss_kl_prob, self.loss_reg, self.loss_kl_topic_gmm]
            self.loss_sum = (self.loss_recon + self.loss_kl_prob + self.loss_kl_sent_gmm + self.loss_kl_root + self.loss_disc_topic + \
                                 self.topic_loss_recon + self.topic_loss_kl_bow + self.topic_loss_kl_prob) * self.n_sents
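
Examples #1 and #4 both build on the TensorFlow 1.x tf.contrib.seq2seq attention stack (Example #4 directly, Example #1 through project-specific subclasses). For reference, here is a minimal self-contained sketch of the standard LuongAttention + AttentionWrapper + dynamic_decode pattern; all sizes and placeholder names are illustrative.

# Minimal TF 1.x sketch of the standard attention-decoder pattern used above.
# Sizes and placeholders are illustrative only.
import tensorflow as tf

batch_size, enc_units, dec_units, emb_dim, vocab_size = 16, 256, 256, 128, 1000

encoder_outputs = tf.placeholder(tf.float32, [batch_size, None, enc_units])
encoder_lengths = tf.placeholder(tf.int32, [batch_size])
decoder_inputs = tf.placeholder(tf.float32, [batch_size, None, emb_dim])  # embedded targets
decoder_lengths = tf.placeholder(tf.int32, [batch_size])

# Attention mechanism over the encoder memory, masked by the true sequence lengths.
attention_mechanism = tf.contrib.seq2seq.LuongAttention(
    num_units=dec_units,
    memory=encoder_outputs,
    memory_sequence_length=encoder_lengths)

# Wrap a plain GRU cell so every decoder step attends over the encoder memory.
attn_cell = tf.contrib.seq2seq.AttentionWrapper(
    tf.contrib.rnn.GRUCell(dec_units),
    attention_mechanism=attention_mechanism,
    attention_layer_size=dec_units,
    alignment_history=True)

helper = tf.contrib.seq2seq.TrainingHelper(decoder_inputs, decoder_lengths)
decoder = tf.contrib.seq2seq.BasicDecoder(
    attn_cell, helper,
    initial_state=attn_cell.zero_state(batch_size, tf.float32),
    output_layer=tf.layers.Dense(vocab_size, use_bias=False))

outputs, final_state, _ = tf.contrib.seq2seq.dynamic_decode(decoder)
logits = outputs.rnn_output  # batch_size x max_decoder_time x vocab_size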