def decoder(mel_targets, encoder_output, scope="decoder", training=True, reuse=None):
    """Attention-based autoregressive mel decoder followed by a convolutional postnet.

    The attention decoder runs for ``hp.T_y // hp.r`` steps inside a
    ``tf.while_loop``. When ``training`` is True it is teacher-forced: step 0
    receives an all-zero "go" frame and step ``t > 0`` receives
    ``mel_targets[:, t - 1, :]``. Otherwise each step receives the previous
    step's projection (starting from the same zero frame).

    Args:
        mel_targets: teacher-forcing targets, indexed as ``mel_targets[:, step, :]``
            (presumably ``(hp.batch_size, hp.T_y // hp.r, hp.n_mels)`` -- TODO confirm).
        encoder_output: memory attended over by ``LocationBasedAttention``.
        scope: name of the enclosing ``tf.variable_scope``.
        training: Python bool; selects teacher forcing and is forwarded to the
            decoder cell and the postnet layers.
        reuse: forwarded to ``tf.variable_scope``.

    Returns:
        ``(mel_logits, final_projection, done_output, final_decoder_state,
        concat_LSTM_att, step)`` where ``mel_logits`` is ``final_projection``
        plus the postnet residual. ``done_output`` and ``concat_LSTM_att`` are
        ``None`` unless ``hp.include_dones`` is set; ``step`` is the loop
        counter returned by the while loop.
    """
    with tf.variable_scope(scope, reuse=reuse):
        decoder_cell = TacotronDecoderWrapper(
            unidirectional_LSTM(training, layers=hp.dec_LSTM_layers, size=hp.dec_LSTM_size),
            training)
        attention_decoder = AttentionWrapper(
            decoder_cell,
            LocationBasedAttention(hp.attention_size, encoder_output),
            alignment_history=True,
            output_attention=False)
        decoder_state = attention_decoder.zero_state(batch_size=hp.batch_size, dtype=tf.float32)

        # All-zero "go" frame: the first decoder input and the initial projection.
        go_frame = tf.tile([[0.0]], [hp.batch_size, hp.n_mels])
        projection = go_frame

        # Fixed-length buffers used as shift registers: every step appends one
        # frame and drops the oldest, so after T_y//r steps they contain exactly
        # the decoded frames in order.
        final_projection = tf.zeros([hp.batch_size, hp.T_y // hp.r, hp.n_mels], tf.float32)
        if hp.include_dones:
            LSTM_att = tf.zeros([hp.batch_size, hp.T_y // hp.r, hp.dec_LSTM_size * 2], tf.float32)
        else:
            LSTM_att = 0  # dummy loop variable, passed through unchanged
        step = 0

        def att_condition(step, projection, final_projection, decoder_state, mel_targets, LSTM_att):
            return step < hp.T_y // hp.r

        def att_body(step, projection, final_projection, decoder_state, mel_targets, LSTM_att):
            if training:
                # `step` is a Tensor here: a Python `if step == 0` cannot choose the
                # go frame (TF1 compares tensors by identity, so it was always False
                # and step 0 was fed mel_targets[:, -1, :]). Choose it in-graph.
                decoder_input = tf.cond(
                    tf.equal(step, 0),
                    lambda: go_frame,
                    lambda: mel_targets[:, step - 1, :])
            else:
                # Free-running: feed back the previous step's projection.
                decoder_input = projection
            projection, decoder_state, _, LSTM_next = attention_decoder.call(decoder_input, decoder_state)

            fprojection = tf.expand_dims(projection, axis=1)
            final_projection = tf.concat([final_projection, fprojection], axis=1)[:, 1:, :]
            if hp.include_dones:
                fLSTM_next = tf.expand_dims(LSTM_next, axis=1)
                LSTM_att = tf.concat([LSTM_att, fLSTM_next], axis=1)[:, 1:, :]
            return ((step + 1), projection, final_projection, decoder_state, mel_targets, LSTM_att)

        res_loop = tf.while_loop(
            att_condition,
            att_body,
            loop_vars=[step, projection, final_projection, decoder_state, mel_targets, LSTM_att],
            parallel_iterations=hp.parallel_iterations,
            swap_memory=False)
        final_projection = res_loop[2]
        final_decoder_state = res_loop[3]
        concat_LSTM_att = res_loop[5]
        step = res_loop[0]
        if hp.print_shapes:
            print(final_projection)

        with tf.variable_scope("postnet"):
            tensor = final_projection
            for i in range(hp.dec_postnet_layers):
                # Layers 0-3 use tanh; any later layer is linear (hard-coded threshold).
                tensor = conv1d(tensor,
                                filters=hp.dec_postnet_filters,
                                kernel_size=hp.dec_postnet_size,
                                activation=tf.nn.tanh if i < 4 else None,
                                training=training,
                                dropout_rate=hp.dropout_rate,
                                scope="decoder_conv_{}".format(i))  # (N, Tx, c)
            # Project the postnet output back to n_mels channels.
            tensor = tf.contrib.layers.fully_connected(tensor, hp.n_mels,
                                                       activation_fn=None,
                                                       biases_initializer=tf.zeros_initializer())
            if hp.print_shapes:
                print(tensor)
            # The postnet predicts a residual added to the decoder output.
            mel_logits = final_projection + tensor
            if hp.print_shapes:
                print(mel_logits)

        if hp.include_dones:
            with tf.variable_scope("done_output"):
                done_output = fc_block(concat_LSTM_att, 2, training=training)
                done_output = tf.nn.sigmoid(done_output)
                if hp.print_shapes:
                    print(done_output)
        else:
            done_output = None
            concat_LSTM_att = None
    return mel_logits, final_projection, done_output, final_decoder_state, concat_LSTM_att, step
# NOTE(review): fragment of a command-line entry point. The parser `ap` and its
# earlier arguments are created before this chunk, and `head_ensembles` /
# `bert_attns` are used after it.
ap.add_argument("--report-result", type=str, default=None,
                help="File where to save the results.")
ap.add_argument("-s", "--sentences", nargs='*', type=int, default=None,
                help="Only use the specified sentences; 0-based")
args = ap.parse_args()

# Build the dependency tree from the conll/tokens inputs, then load the
# attention matrices using the tree's wordpieces2tokens mapping, restricted to
# args.sentences (None = no restriction, presumably -- confirm in AttentionWrapper).
dependency_tree = Dependency(args.conll, args.tokens)
bert_attns = AttentionWrapper(args.attentions, dependency_tree.wordpieces2tokens, args.sentences)

if args.evaluate_only:
    # Evaluate-only mode needs previously saved head ensembles from a JSON file.
    if not args.json:
        raise ValueError(
            "JSON with head ensembles required in evaluate only mode!")
    with open(args.json, 'r') as inj:
        head_ensembles = json.load(inj)
    # Rebuild HeadEnsemble objects from their serialized dicts, keeping the JSON keys.
    head_ensembles = {
        rl: HeadEnsemble.from_dict(**he_dict)
        for rl, he_dict in head_ensembles.items()
    }
else:
    # No saved ensembles: start empty (presumably filled later, outside this chunk).
    head_ensembles = dict()
ax.set_yticks(np.arange(len(sentence_tokens))) ax.set_xticks(np.arange(len(sentence_tokens))) ax.set_xticklabels(sentence_tokens, rotation=90) ax.set_yticklabels(sentence_tokens) ax.set_ylim(top=-0.5, bottom=len(sentence_tokens) - 0.5) plt.savefig(out_file, dpi=300, format='pdf') if __name__ == '__main__': ap = argparse.ArgumentParser() ap.add_argument("attentions", type=str, help="NPZ file with attentions") ap.add_argument("tokens", type=str, help="Labels (tokens) separated by spaces") ap.add_argument("conll", type=str, help="Conll file for head selection.") ap.add_argument("-layer_idcs", nargs="*", type=int, default=[5, 3], help = "layer indices to plot") ap.add_argument("-head_idcs", nargs="*", type=int, default=[4, 9], help="head indices to plot") ap.add_argument("-s", "--sentences", nargs='*', type=int, default=list(range(10)), help="Only use the specified sentences; 0-based") ap.add_argument("-vis-dir", type=str, default="../results", help="Directory where to save head visualizations") args = ap.parse_args() dependency_tree = Dependency(args.conll, args.tokens) bert_attns = AttentionWrapper(args.attentions, dependency_tree.wordpieces2tokens, args.sentences) for sent_idx, attn_mats in bert_attns: for l, h in zip(args.layer_idcs, args.head_idcs): out_file = os.path.join(args.vis_dir, f"L-{l}_H-{h}_sent-{sent_idx}.pdf") plot_head(attn_mats, dependency_tree.tokens[sent_idx], l, h, out_file)
def build(self):
    """Build the TensorFlow (v1) graph for the model: embeddings, sentence
    encoder, topic posterior, decoders, discriminator, losses and optimizers.

    Inputs are read from ``self.t_variables`` (keys used here include
    'sample', 'batch_l', 'doc_l', 'sent_l', 'dec_sent_l', 'enc_keep_prob',
    'enc_input_idxs', 'dec_keep_prob', 'dec_input_idxs', 'dec_target_idxs' and,
    when ``config.topic_model`` is set, 'doc_bows'). Hyperparameters come from
    ``self.config``. All results are stored as attributes on ``self``, e.g.
    ``self.loss``, ``self.loss_disc``, ``self.opt``, ``self.opt_disc``,
    ``self.loss_list_train``, ``self.loss_list_eval`` and ``self.loss_sum``.

    NOTE(review): several attributes read here are never assigned in this
    method -- ``self.means_sent_posterior``, ``self.logvars_sent_posterior``,
    ``self.latents_sent_posterior`` and, depending on config,
    ``self.probs_doc_topic_posterior``, ``self.logits_bow``,
    ``self.means_topic_bow_posterior``, ``self.logvars_topic_bow_posterior``,
    ``self.means_probs_doc_topic_posterior``,
    ``self.logvars_probs_doc_topic_posterior``,
    ``self.means_state_root_posterior``, ``self.logvars_state_root_posterior``
    and ``self.topic_bow``. Presumably they are set by the helpers that receive
    ``self`` (``encode_inputs``, ``encode_gsm_probs_topic_posterior``, ...);
    confirm.
    """
    self.global_step = tf.Variable(0, name='global_step', trainable=False)
    # Gumbel/softmax temperature: linearly decreases from max_temperature by
    # global_step / linear_steps, floored at min_temperature.
    self.softmax_temperature = tf.maximum( \
        self.config.max_temperature-tf.cast(tf.divide(self.global_step, tf.constant(self.config.linear_steps)), dtype=tf.float32), \
        self.config.min_temperature)

    with tf.name_scope('t_variables'):
        self.sample = self.t_variables['sample']
        self.batch_l = self.t_variables['batch_l']
        self.doc_l = self.t_variables['doc_l']
        self.sent_l = self.t_variables['sent_l']
        self.dec_sent_l = self.t_variables[
            'dec_sent_l']  # batch_l x max_doc_l
        self.max_doc_l = tf.reduce_max(self.doc_l)
        self.max_sent_l = tf.reduce_max(self.sent_l)
        self.max_dec_sent_l = tf.reduce_max(
            self.dec_sent_l)  # = max_sent_l + 1
        self.mask_doc = tf.sequence_mask(self.doc_l, dtype=tf.float32)
        self.mask_sent = tf.sequence_mask(self.sent_l, dtype=tf.float32)
        # Vocabulary mask: 1 at config.bow_idxs, 0 elsewhere.
        mask_bow = np.zeros(self.config.n_vocab)
        mask_bow[self.config.bow_idxs] = 1.
        self.mask_bow = tf.constant(mask_bow, dtype=tf.float32)
        self.enc_keep_prob = self.t_variables['enc_keep_prob']

    # ------------------------------Encoder ------------------------------
    with tf.variable_scope('emb'):
        with tf.variable_scope('word', reuse=False):
            # Row 0 of the embedding matrix is a constant all-zero vector (not a
            # variable); rows 1..n_vocab-1 are trainable.
            pad_embedding = tf.zeros([1, self.config.dim_emb], dtype=tf.float32)
            nonpad_embeddings = tf.get_variable('emb', [self.config.n_vocab-1, self.config.dim_emb], dtype=tf.float32, \
                                                initializer=tf.contrib.layers.xavier_initializer())
            self.embeddings = tf.concat([pad_embedding, nonpad_embeddings], 0)  # n_vocab x dim_emb
            self.bow_embeddings = tf.nn.embedding_lookup(
                self.embeddings, self.config.bow_idxs)  # dim_bow x dim_emb

        # get sentence embeddings
        # (one-hot followed by tensordot: the same lookup form is later applied to
        # soft/one-hot decoder outputs)
        self.enc_input_idxs = tf.one_hot(
            self.t_variables['enc_input_idxs'], depth=self.config.n_vocab
        )  # batch_l x max_doc_l x max_sent_l x n_vocab
        self.enc_inputs = tf.tensordot(
            self.enc_input_idxs, self.embeddings,
            axes=[[-1], [0]])  # batch_l x max_doc_l x max_sent_l x dim_emb

        # Variables created in 'emb/sent' are reused below when config.enc == 'rnn'.
        with tf.variable_scope('sent', reuse=False):
            self.sent_outputs, self.sent_state = \
                encode_inputs(self, enc_inputs=self.enc_inputs, sent_l=self.sent_l)  # batch_l x max_doc_l x dim_hidden*2

    with tf.variable_scope('enc'):
        # get sentence latents
        with tf.variable_scope('latents_sent', reuse=False):
            self.w_topic_posterior = tf.get_variable(
                'topic_posterior/kernel', [
                    self.config.n_topic, self.sent_state.shape[-1],
                    self.config.dim_hidden
                ],
                dtype=tf.float32)
            self.b_topic_posterior = tf.get_variable(
                'topic_posterior/bias',
                [1, self.config.n_topic, self.config.dim_hidden],
                dtype=tf.float32)
            # Masked mean of sentence states over each document's valid sentences.
            self.topic_state = tf.reduce_sum(
                self.sent_state * tf.expand_dims(self.mask_doc, -1),
                -2) / tf.reduce_sum(self.mask_doc, -1, keepdims=True)
            # Per-topic linear projection of the document state.
            self.hidden_topic_posterior = tf.tensordot(
                self.topic_state, self.w_topic_posterior, axes=[[1], [1]]
            ) + self.b_topic_posterior  # batch_l x n_topic x dim_hidden

    # ------------------------------Discriminator------------------------------
    # Variables created in 'disc/prob_topic' are reused below to score summaries.
    with tf.variable_scope('disc'):
        with tf.variable_scope('prob_topic', reuse=False):
            # encode by TSNTM
            self.probs_sent_topic_posterior, _, _ = \
                encode_gsm_probs_topic_posterior(self, self.hidden_topic_posterior.get_shape()[-1], self.hidden_topic_posterior, self.mask_doc, self.config)  # batch_l x max_doc_l x n_topic

    with tf.name_scope('latents_topic'):
        # get topic sentence posterior distribution for each document
        self.probs_topic_posterior = tf.reduce_sum(
            self.probs_sent_topic_posterior, 1)  # batch_l x n_topic

        # Per-topic mean: sentence means weighted by topic probability, divided
        # by the summed probability for that topic.
        self.means_sent_topic_posterior = tf.multiply(tf.expand_dims(self.probs_sent_topic_posterior, -1), \
                                                      tf.expand_dims(self.means_sent_posterior, -2))  # batch_l x max_doc_l x n_topic x dim_latent
        self.means_topic_posterior_ = tf.reduce_sum(self.means_sent_topic_posterior, 1) / \
                                      tf.expand_dims(self.probs_topic_posterior, -1)  # batch_l x n_topic x dim_latent
        self.means_topic_posterior = tf_clip_means(
            self.means_topic_posterior_, self.probs_topic_posterior)

        # Per-topic full covariance: probability-weighted sum of each sentence's
        # diagonal covariance (exp(logvar)) plus the outer product of its
        # deviation from the topic mean, normalised by the summed probability.
        diffs_sent_topic_posterior = tf.expand_dims(self.means_sent_posterior, 2) - \
                                     tf.expand_dims(self.means_topic_posterior, 1)  # batch_l x max_doc_l x n_topic x dim_latent
        self.covs_sent_topic_posterior = tf.multiply(tf.expand_dims(tf.expand_dims(self.probs_sent_topic_posterior, -1), -1), \
                                                     tf.matrix_diag(tf.expand_dims(tf.exp(self.logvars_sent_posterior), 2)) + tf.matmul(tf.expand_dims(diffs_sent_topic_posterior, -1), \
                                                     tf.expand_dims(diffs_sent_topic_posterior, -2)))  # batch_l x max_doc_l x n_topic x dim_latent x dim_latent
        self.covs_topic_posterior_ = tf.reduce_sum(self.covs_sent_topic_posterior, 1) / \
                                     tf.expand_dims(tf.expand_dims(self.probs_topic_posterior, -1), -1)  # batch_l x n_topic x dim_latent x dim_latent
        self.covs_topic_posterior = tf_clip_covs(
            self.covs_topic_posterior_, self.probs_topic_posterior)

        self.latents_topic_posterior = sample_latents_fullcov(self.means_topic_posterior, self.covs_topic_posterior, \
                                                              seed=self.config.seed, sample=self.sample)

        # Prior: zero mean, identity covariance scaled by config.cov_root.
        self.means_topic_prior = tf.zeros(
            [
                self.batch_l, self.config.n_topic, self.config.dim_latent
            ],
            dtype=tf.float32)  # batch_l x n_topic x dim_latent
        self.covs_topic_prior = tf.eye(
            self.config.dim_latent,
            batch_shape=[self.batch_l, self.config.n_topic],
            dtype=tf.float32) * self.config.cov_root

    # ------------------------------Decoder----------------------------------
    # The first 'dec/sent' scope creates the decoder variables (reuse=False);
    # the later 'sent' scopes set reuse=True so topic decoding and beam search
    # share them.
    with tf.variable_scope('dec'):
        # decode for training sent
        with tf.variable_scope(
                'sent',
                initializer=tf.contrib.layers.xavier_initializer(),
                dtype=tf.float32,
                reuse=False):
            self.dec_cell = tf.contrib.rnn.GRUCell(self.config.dim_hidden)
            self.dec_cell = tf.contrib.rnn.DropoutWrapper(
                self.dec_cell,
                output_keep_prob=self.t_variables['dec_keep_prob'])
            self.dec_sent_cell = self.dec_cell
            self.latent_hidden_layer = tf.layers.Dense(
                units=self.config.dim_hidden,
                activation=tf.nn.relu,
                name='latent_hidden_linear')
            self.dec_sent_initial_state = self.latent_hidden_layer(
                self.latents_sent_posterior
            )  # batch_l x max_doc_l x dim_hidden
            self.output_layer = tf.layers.Dense(self.config.n_vocab,
                                                use_bias=False,
                                                name='out')

            if self.config.attention:
                # Luong attention over the encoder outputs of the sentence being
                # reconstructed; lengths are clamped to >= 1.
                self.sent_outputs_flat = tf.reshape(
                    self.sent_outputs, [
                        self.batch_l * self.max_doc_l, self.max_sent_l,
                        self.config.dim_hidden * 2
                    ])
                self.att_sent_l_flat = tf.reshape(
                    tf.maximum(self.sent_l, 1),
                    [self.batch_l * self.max_doc_l])
                self.att_sent_mechanism = tf.contrib.seq2seq.LuongAttention(num_units=self.config.dim_hidden,
                                                                            memory=self.sent_outputs_flat, \
                                                                            memory_sequence_length=self.att_sent_l_flat)
                self.att_cell = tf.contrib.seq2seq.AttentionWrapper(
                    self.dec_cell,
                    attention_mechanism=self.att_sent_mechanism,
                    attention_layer_size=self.config.dim_hidden)
                self.dec_sent_cell = self.att_cell

            # teacher forcing
            self.dec_input_idxs = self.t_variables[
                'dec_input_idxs']  # batch_l x max_doc_l x max_dec_sent_l
            self.dec_inputs = tf.nn.embedding_lookup(
                self.embeddings, self.dec_input_idxs
            )  # batch_l x max_doc_l x max_dec_sent_l x dim_emb

            # output_sent_l == dec_sent_l
            self.output_logits_flat, self.output_sent_l_flat = decode_output_logits_flat(
                self,
                dec_cell=self.dec_sent_cell,
                dec_initial_state=self.dec_sent_initial_state,
                dec_inputs=self.dec_inputs,
                dec_sent_l=self.dec_sent_l,
                latents_input=self.latents_sent_posterior
            )  # batch_l*max_doc_l x max_output_sent_l x n_vocab

            self.output_sent_l = tf.reshape(self.output_sent_l_flat, [self.batch_l, self.max_doc_l])
            self.max_output_sent_l = tf.reduce_max(self.output_sent_l)
            self.output_logits = tf.reshape(self.output_logits_flat, \
                                            [self.batch_l, self.max_doc_l, self.max_output_sent_l, self.config.n_vocab], name='output_logits')
            if self.config.disc_gumbel:
                self.output_input_idxs = sample_gumbels(
                    self.output_logits, self.softmax_temperature,
                    self.config.seed, self.sample
                )  # batch_l x max_doc_l x max_output_sent_l x n_vocab
            else:
                self.output_input_idxs = self.output_logits

        # decode for training topic probs
        with tf.variable_scope(
                'sent',
                initializer=tf.contrib.layers.xavier_initializer(),
                dtype=tf.float32,
                reuse=True):
            self.dec_topic_cell = self.dec_cell
            if self.config.attention:
                # Each topic attends over all sentence outputs of its document;
                # the batch is tiled n_topic times.
                self.topic_outputs_flat = tf.contrib.seq2seq.tile_batch(tf.reshape(self.sent_outputs, \
                                                                                   [self.batch_l, self.max_doc_l*self.max_sent_l, self.sent_outputs.get_shape()[-1]]), \
                                                                        multiplier=self.config.n_topic)  # batch_l*n_topic x max_doc_l*max_sent_l x dim_hidden*2
                self.score_mask = tf.contrib.seq2seq.tile_batch(tf.reshape(tf.sequence_mask(self.sent_l), \
                                                                           [self.batch_l, self.max_doc_l*self.max_sent_l]), multiplier=self.config.n_topic)  # batch_l*n_topic x max_doc_l*max_sent_l
                # Sentence-level scores for HierarchicalAttention: the posterior
                # probability of each sentence under each topic.
                self.hier_score = tf.reshape(tf.transpose(self.probs_sent_topic_posterior, [0, 2, 1]), \
                                             [self.batch_l*self.config.n_topic, self.max_doc_l])  # batch_l*n_topic x max_doc_l

                self.att_topic_mechanism = HierarchicalAttention(
                    num_units=self.config.dim_hidden,
                    memory=self.topic_outputs_flat,
                    score_mask=self.score_mask,
                    hier_score=self.hier_score)
                self.att_topic_cell = AttentionWrapper(
                    self.dec_cell,
                    attention_mechanism=self.att_topic_mechanism,
                    attention_layer_size=self.config.dim_hidden)
                self.dec_topic_cell = self.att_topic_cell

            # Topic sentences are decoded from sampled topic latents, or from the
            # topic means when config.disc_mean is set.
            if not self.config.disc_mean:
                self.dec_topic_initial_state = self.latent_hidden_layer(
                    self.latents_topic_posterior)
                dec_topic_outputs, self.summary_sent_l_flat = decode_output_sample_flat(
                    self,
                    dec_cell=self.dec_topic_cell,
                    dec_initial_state=self.dec_topic_initial_state,
                    softmax_temperature=self.softmax_temperature,
                    sample=self.sample,
                    latents_input=self.latents_topic_posterior
                )  # batch_l*max_doc_l x max_summary_sent_l x n_vocab
            else:
                self.dec_topic_initial_state = self.latent_hidden_layer(
                    self.means_topic_posterior)
                dec_topic_outputs, self.summary_sent_l_flat = decode_output_sample_flat(
                    self,
                    dec_cell=self.dec_topic_cell,
                    dec_initial_state=self.dec_topic_initial_state,
                    softmax_temperature=self.softmax_temperature,
                    sample=self.sample,
                    latents_input=self.means_topic_posterior
                )  # batch_l*max_doc_l x max_summary_sent_l x n_vocab

            self.summary_sent_l = tf.reshape(
                self.summary_sent_l_flat,
                [self.batch_l, self.config.n_topic])
            self.max_summary_sent_l = tf.reduce_max(self.summary_sent_l)
            if self.config.disc_gumbel:
                summary_input_idxs_flat = dec_topic_outputs.sample_id
            else:
                summary_input_idxs_flat = dec_topic_outputs.rnn_output
            self.summary_input_idxs = tf.reshape(summary_input_idxs_flat, \
                                                 [self.batch_l, self.config.n_topic, self.max_summary_sent_l, self.config.n_vocab], name='summary_input_idxs')

            # re-encode topic sentence outputs
            self.summary_inputs = tf.tensordot(
                self.summary_input_idxs,
                self.embeddings,
                axes=[[-1], [0]])  # batch_l x n_topic x max_summary_sent_l x dim_emb
            self.summary_input_sent_l = self.summary_sent_l - 1  # to remove EOS
            self.mask_summary_sent = tf.sequence_mask(self.summary_input_sent_l, \
                                                      maxlen=self.max_summary_sent_l, dtype=tf.float32)  # batch_l x n_topic x max_summary_sent_l
            self.mask_summary_doc = tf.ones(
                [self.batch_l, self.config.n_topic], dtype=tf.float32)

        # beam decode for inference of original sentences
        with tf.variable_scope(
                'sent',
                initializer=tf.contrib.layers.xavier_initializer(),
                dtype=tf.float32,
                reuse=True):
            self.beam_dec_sent_cell = self.dec_cell
            if self.config.attention:
                # Memory and lengths are tiled beam_width times for beam search.
                self.beam_sent_outputs_flat = tf.contrib.seq2seq.tile_batch(
                    self.sent_outputs_flat, multiplier=self.config.beam_width)
                self.beam_att_sent_l_flat = tf.contrib.seq2seq.tile_batch(
                    self.att_sent_l_flat, multiplier=self.config.beam_width)
                self.beam_att_sent_mechanism = tf.contrib.seq2seq.LuongAttention(
                    num_units=self.config.dim_hidden,
                    memory=self.beam_sent_outputs_flat,
                    memory_sequence_length=self.beam_att_sent_l_flat)
                self.beam_dec_sent_cell = tf.contrib.seq2seq.AttentionWrapper(
                    self.beam_dec_sent_cell,
                    attention_mechanism=self.beam_att_sent_mechanism,
                    attention_layer_size=self.config.dim_hidden)

            # infer original sentences
            self.beam_output_idxs, _, _ = decode_beam_output_token_idxs(
                self,
                beam_dec_cell=self.beam_dec_sent_cell,
                dec_initial_state=self.dec_sent_initial_state,
                latents_input=self.means_sent_posterior,
                name='beam_output_idxs')

        # beam decode for inference of topic sentences
        with tf.variable_scope(
                'sent',
                initializer=tf.contrib.layers.xavier_initializer(),
                dtype=tf.float32,
                reuse=True):
            self.beam_dec_topic_cell = self.dec_cell
            if self.config.attention:
                self.beam_topic_outputs_flat = tf.contrib.seq2seq.tile_batch(
                    self.topic_outputs_flat, multiplier=self.config.beam_width)
                self.beam_score_mask = tf.contrib.seq2seq.tile_batch(
                    self.score_mask, multiplier=self.config.beam_width)
                self.beam_hier_score = tf.contrib.seq2seq.tile_batch(
                    self.hier_score, multiplier=self.config.beam_width)
                self.beam_att_topic_mechanism = HierarchicalAttention(
                    num_units=self.config.dim_hidden,
                    memory=self.beam_topic_outputs_flat,
                    score_mask=self.beam_score_mask,
                    hier_score=self.beam_hier_score)
                self.beam_dec_topic_cell = AttentionWrapper(
                    self.beam_dec_topic_cell,
                    attention_mechanism=self.beam_att_topic_mechanism,
                    attention_layer_size=self.config.dim_hidden)

            # infer topic sentences
            self.beam_summary_idxs, _, _ = decode_beam_output_token_idxs(
                self,
                beam_dec_cell=self.beam_dec_topic_cell,
                dec_initial_state=self.dec_topic_initial_state,
                latents_input=self.latents_topic_posterior,
                name='beam_summary_idxs')

            # Tokens equal to EOS_IDX are masked out: the length counts non-EOS
            # tokens, and masked positions are replaced by index 0 (the all-zero
            # embedding row) before the one-hot embedding lookup.
            self.beam_mask_summary_sent = tf.logical_not(tf.equal(self.beam_summary_idxs, \
                                                                  self.config.EOS_IDX))  # batch_l x n_topic x max_summary_sent_l
            self.beam_summary_input_sent_l = tf.reduce_sum(
                tf.cast(self.beam_mask_summary_sent, tf.int32),
                -1)  # batch_l x n_topic
            beam_summary_soft_idxs = tf.one_hot(tf.where(self.beam_mask_summary_sent, \
                                                         self.beam_summary_idxs, tf.zeros_like(self.beam_summary_idxs)), depth=self.config.n_vocab)
            self.beam_summary_inputs = tf.tensordot(beam_summary_soft_idxs, \
                                                    self.embeddings, [[-1], [0]])  # batch_l x n_topic x max_beam_summary_sent_l x dim_emb

    # ------------------------------Discriminator------------------------------
    # Re-encode the decoded topic sentences with the encoder selected by
    # config.enc, reusing existing variables.
    # NOTE(review): for any other config.enc value self.summary_state is not set
    # here, and the 'disc' block below fails unless it is set elsewhere.
    # encode by MLP
    if self.config.enc == 'mlp':
        with tf.variable_scope('disc'):
            with tf.variable_scope('prob_topic', reuse=True):
                self.summary_state = encode_states(self, enc_inputs=self.summary_inputs, mask_sent=self.mask_summary_sent, \
                                                   enc_keep_prob=self.enc_keep_prob, config=self.config)  # batch_l x n_topic x dim_hidden
    elif self.config.enc == 'bow':
        with tf.variable_scope('disc'):
            with tf.variable_scope('prob_topic', reuse=True):
                # Keep only the bow_idxs vocabulary entries before embedding.
                self.bow_summary_input_idxs = tf.multiply(
                    self.summary_input_idxs, self.mask_bow)
                self.bow_summary_inputs = tf.tensordot(
                    self.bow_summary_input_idxs,
                    self.embeddings,
                    axes=[[-1], [0]
                          ])  # batch_l x max_doc_l x max_sent_l x dim_emb
                self.mask_summary_bow = tf.reduce_sum(
                    self.bow_summary_input_idxs, -1)
                self.summary_state = encode_states(self, enc_inputs=self.bow_summary_inputs, mask_sent=self.mask_summary_bow, \
                                                   enc_keep_prob=self.enc_keep_prob, config=self.config)  # batch_l x max_doc_l x dim_hidden
    elif self.config.enc == 'rnn':
        with tf.variable_scope('emb'):
            with tf.variable_scope('sent', reuse=True):
                _, self.summary_state = encode_inputs(
                    self,
                    enc_inputs=self.summary_inputs,
                    sent_l=self.summary_input_sent_l
                )  # batch_l x max_doc_l x dim_hidden*2
                _, self.beam_summary_state = encode_inputs(
                    self,
                    enc_inputs=self.beam_summary_inputs,
                    sent_l=self.beam_summary_input_sent_l
                )  # batch_l x max_doc_l x dim_hidden*2

    with tf.variable_scope('disc'):
        with tf.variable_scope('prob_topic', reuse=True):
            self.probs_summary_topic_posterior, _, _ = \
                encode_gsm_probs_topic_posterior(self, self.summary_state.get_shape()[-1], self.summary_state, self.mask_summary_doc, self.config)
            # Diagonal entry k: the probability assigned to topic k for the
            # summary decoded from topic k, then log-transformed.
            self.logits_summary_topic_posterior_ = tf_log(
                tf.matrix_diag_part(self.probs_summary_topic_posterior)
            )  # batch_l x n_topic
            self.logits_summary_topic_posterior = tf_clip_vals(
                self.logits_summary_topic_posterior_,
                self.probs_topic_posterior)

    # ------------------------------Optimizer and Loss------------------------------
    with tf.name_scope('opt'):
        # 1 for valid sentences (mask_doc), 0 for padding; dynamic_partition(...)[1]
        # below keeps only the valid entries before averaging.
        partition_doc = tf.cast(self.mask_doc, dtype=tf.int32)
        self.n_sents = tf.cast(tf.reduce_sum(self.doc_l), dtype=tf.float32)
        self.n_tokens = tf.reduce_sum(self.dec_sent_l)

        # ------------------------------Reconstruction Loss of Language Model------------------------------
        # target and mask
        # (dec_sent_l / max_dec_sent_l re-read the same t_variables entry as at
        # the top of this method)
        self.dec_target_idxs = self.t_variables[
            'dec_target_idxs']  # batch_l x max_doc_l x max_dec_sent_l
        self.dec_sent_l = self.t_variables[
            'dec_sent_l']  # batch_l x max_doc_l
        self.max_dec_sent_l = tf.reduce_max(
            self.dec_sent_l)  # = max_sent_l + 1
        self.dec_mask_sent = tf.sequence_mask(self.dec_sent_l,
                                              maxlen=self.max_dec_sent_l,
                                              dtype=tf.float32)
        self.dec_target_idxs_flat = tf.reshape(
            self.dec_target_idxs,
            [self.batch_l * self.max_doc_l, self.max_dec_sent_l])
        self.dec_mask_sent_flat = tf.reshape(
            self.dec_mask_sent,
            [self.batch_l * self.max_doc_l, self.max_dec_sent_l])

        # nll for each token (summed over sentence)
        # When config.sample is set, logits/targets are truncated to the shorter
        # of the two lengths; otherwise the slice bound is None (no truncation).
        self.recon_max_sent_l = tf.minimum(
            self.max_dec_sent_l,
            self.max_output_sent_l) if self.config.sample else None
        losses_recon_flat = tf.reduce_sum(
            tf.contrib.seq2seq.sequence_loss(
                self.output_logits_flat[:, :self.recon_max_sent_l, :],
                self.dec_target_idxs_flat[:, :self.recon_max_sent_l],
                self.dec_mask_sent_flat[:, :self.recon_max_sent_l],
                average_across_timesteps=False,
                average_across_batch=False),
            -1)  # batch_l*max_doc_l
        self.losses_recon = tf.reshape(losses_recon_flat,
                                       [self.batch_l, self.max_doc_l])
        self.loss_recon = tf.reduce_mean(
            tf.dynamic_partition(
                self.losses_recon, partition_doc,
                num_partitions=2)[1])  # average over doc x batch

        # ------------------------------KL divergence Loss of Topic Probability Distribution------------------------------
        if self.config.topic_model:
            self.probs_sent_topic_prior = tf.expand_dims(
                self.probs_doc_topic_posterior, 1)  # batch_l x 1 x n_topic
        else:
            self.probs_sent_topic_prior = tf.ones_like(self.probs_sent_topic_posterior, dtype=tf.float32) / \
                                          self.config.n_topic  # batch_l x max_doc_l x n_topic, uniform distribution over topics
        # KL(posterior || prior) per sentence, summed over topics.
        self.losses_kl_prob = tf.reduce_sum(tf.multiply(self.probs_sent_topic_posterior, \
                                                        (tf_log(self.probs_sent_topic_posterior)-tf_log(self.probs_sent_topic_prior))), -1)
        self.loss_kl_prob = tf.reduce_mean(
            tf.dynamic_partition(
                self.losses_kl_prob, partition_doc,
                num_partitions=2)[1])  # average over doc x batch

        # ------------------------------KL divergence Loss of Sentence Latents Distribution------------------------------
        self.losses_kl_sent_gauss = compute_kl_losses_sent_gauss(
            self
        )  # batch_l x max_doc_l x n_topic, sum over latent dimension
        self.losses_kl_sent_gmm = tf.reduce_sum(
            tf.multiply(self.probs_sent_topic_posterior,
                        self.losses_kl_sent_gauss),
            -1)  # batch_l x max_doc_l, sum over topics
        self.loss_kl_sent_gmm = tf.reduce_mean(
            tf.dynamic_partition(
                self.losses_kl_sent_gmm, partition_doc,
                num_partitions=2)[1])  # average over doc x batch

        # ------------------------------KL divergence Loss of Topic Latents Distribution------------------------------
        if self.config.reverse_kl:
            self.losses_kl_topic_pairs_gauss = compute_kl_losses_topic_paris_gauss(
                self)
            # Mean over the topics selected by config.mask_tree (denominator
            # clamped to >= 1 to avoid dividing by zero).
            self.losses_kl_topic_gauss_reverse = tf.reduce_sum(self.losses_kl_topic_pairs_gauss * self.config.mask_tree[None, None, :, :], -1) / \
                                                 np.maximum(np.sum(self.config.mask_tree[None, None, :, :], -1), 1)  # batch_l x 1 x n_topic, mean over other child topics
            self.losses_kl_topic_gmm_reverse = tf.reduce_sum(
                tf.multiply(self.probs_sent_topic_posterior,
                            self.losses_kl_topic_gauss_reverse),
                -1)  # batch_l x max_doc_l, sum over topics
            self.loss_kl_topic_gmm_reverse = tf.reduce_mean(
                tf.dynamic_partition(self.losses_kl_topic_gmm_reverse,
                                     partition_doc,
                                     num_partitions=2)[1])
        else:
            self.loss_kl_topic_gmm_reverse = tf.constant(0.,
                                                         dtype=tf.float32)

        # for monitor (only reported in loss_list_eval; not part of self.loss)
        self.losses_kl_topic_gauss = compute_kl_losses_topic_gauss(
            self)  # batch_l x 1 x n_topic, sum over latent dimension
        self.losses_kl_topic_gmm = tf.reduce_sum(
            tf.multiply(self.probs_sent_topic_posterior,
                        self.losses_kl_topic_gauss),
            -1)  # batch_l x max_doc_l, sum over topics
        self.loss_kl_topic_gmm = tf.reduce_mean(
            tf.dynamic_partition(self.losses_kl_topic_gmm,
                                 partition_doc,
                                 num_partitions=2)[1])

        # ------------------------------KL divergence Loss of Root State Distribution------------------------------
        if self.config.prior_root:
            self.losses_kl_root = compute_kl_losses(
                self.means_state_root_posterior,
                self.logvars_state_root_posterior)  # batch_l x max_doc_l
            self.loss_kl_root = tf.reduce_sum(
                self.losses_kl_root) / tf.cast(
                    tf.reduce_sum(self.doc_l),
                    dtype=tf.float32)  # average over doc x batch
        else:
            self.loss_kl_root = tf.constant(0, dtype=tf.float32)

        # ------------------------------Discriminator Loss------------------------------
        if self.config.disc_topic:
            # Negative log-probability that each topic's summary is classified
            # back to that topic.
            self.losses_disc_topic = -tf.reduce_sum(
                self.logits_summary_topic_posterior,
                -1)  # batch_l, sum over topic
            self.loss_disc_topic = tf.reduce_sum(
                self.losses_disc_topic
            ) / self.n_sents  # average over doc x batch
        else:
            self.loss_disc_topic = tf.constant(0, dtype=tf.float32)

        # ------------------------------Loss of Topic Model------------------------------
        if self.config.topic_model:
            # recon
            self.topic_losses_recon = -tf.reduce_sum(
                tf.multiply(self.t_variables['doc_bows'], self.logits_bow),
                -1)  # n_batch, sum over n_bow
            self.topic_loss_recon = tf.reduce_mean(
                self.topic_losses_recon)  # average over doc x batch

            # kl_bow
            self.means_topic_bow_prior = tf.squeeze(get_params_topic_prior(self, tf.expand_dims(self.means_topic_bow_posterior, 0), \
                                                                           tf.zeros([1, self.config.dim_latent], dtype=tf.float32)), 0)  # n_topic x dim_latent
            self.logvars_topic_bow_prior = tf.squeeze(get_params_topic_prior(self, tf.expand_dims(self.logvars_topic_bow_posterior, 0), \
                                                                             tf.zeros([1, self.config.dim_latent], dtype=tf.float32)), 0)  # n_topic x dim_latent
            self.topic_losses_kl_bow = compute_kl_losses(self.means_topic_bow_posterior, self.logvars_topic_bow_posterior, \
                                                         means_prior=self.means_topic_bow_prior, logvars_prior=self.logvars_topic_bow_prior)  # n_topic
            self.topic_loss_kl_bow = tf.reduce_mean(
                self.topic_losses_kl_bow)  # average over doc x batch

            # kl_prob
            self.topic_losses_kl_prob = compute_kl_losses(
                self.means_probs_doc_topic_posterior,
                self.logvars_probs_doc_topic_posterior)  # batch_l
            self.topic_loss_kl_prob = tf.reduce_mean(
                self.topic_losses_kl_prob)  # average over doc x batch
        else:
            self.topic_loss_recon = tf.constant(0, dtype=tf.float32)
            self.topic_loss_kl_bow = tf.constant(0, dtype=tf.float32)
            self.topic_loss_kl_prob = tf.constant(0, dtype=tf.float32)

        # ------------------------------Topic Regularization Loss------------------------------
        # NOTE(review): if config.reg is non-empty but neither 'mean' nor 'bow',
        # self.topic_dots is not assigned in this method.
        if self.config.reg != '':
            if self.config.reg == 'mean':
                self.topic_dots = self.get_topic_dots(
                    self.means_topic_posterior
                )  # batch_l x n_topic-1 x n_topic-1
            elif self.config.reg == 'bow':
                self.topic_dots = self.get_topic_dots(
                    tf.expand_dims(
                        self.topic_bow,
                        0))  # batch_l(=1) x n_topic-1 x n_topic-1

            # Squared deviation of topic_dots from the identity, restricted by
            # config.mask_tree_reg and normalised by the mask's total.
            self.losses_reg = tf.reduce_sum(tf.square(self.topic_dots - tf.eye(len(self.config.all_child_idxs))) * self.config.mask_tree_reg, [1, 2])\
                              / tf.reduce_sum(self.config.mask_tree_reg)  # batch_l
            self.loss_reg = tf.reduce_mean(
                self.losses_reg)  # average over batch
        else:
            self.loss_reg = tf.constant(0, dtype=tf.float32)

        # ------------------------------Optimizer------------------------------
        # KL weight beta: 'linear' ramps beta_init up by global_step/linear_steps;
        # 'cycle' restarts the ramp every cycle_steps; both are capped at 1.
        if self.config.anneal == 'linear':
            self.tau = tf.cast(tf.divide(
                self.global_step,
                tf.constant(self.config.linear_steps)),
                               dtype=tf.float32)
            self.beta = tf.minimum(1., self.config.beta_init + self.tau)
        elif self.config.anneal == 'cycle':
            self.tau = tf.cast(tf.divide(
                tf.mod(self.global_step,
                       tf.constant(self.config.cycle_steps)),
                tf.constant(self.config.cycle_steps)),
                               dtype=tf.float32)
            self.beta = tf.minimum(
                1., self.config.beta_init + self.tau /
                (1. - self.config.r_cycle))
        else:
            self.beta = tf.constant(1.)

        self.beta_disc = self.beta if self.config.beta_disc else tf.constant(
            1.)

        def get_opt(loss, var_list, lr, global_step=None):
            """Return (train_op, grad_vars, clipped_grad_vars) for `loss` over `var_list`.

            Gradients are clipped element-wise to [-grad_clip, grad_clip];
            variables with no gradient are dropped.
            NOTE(review): if config.opt is neither 'adam' nor 'adagrad',
            `Optimizer` is unbound and this raises UnboundLocalError.
            """
            if self.config.opt == 'adam':
                Optimizer = tf.train.AdamOptimizer
            elif self.config.opt == 'adagrad':
                Optimizer = tf.train.AdagradOptimizer
            optimizer = Optimizer(lr)
            grad_vars = optimizer.compute_gradients(loss=loss,
                                                    var_list=var_list)
            clipped_grad_vars = [
                (tf.clip_by_value(grad, -self.config.grad_clip,
                                  self.config.grad_clip), var)
                for grad, var in grad_vars if grad is not None
            ]
            opt = optimizer.apply_gradients(clipped_grad_vars,
                                            global_step=global_step)
            return opt, grad_vars, clipped_grad_vars

        # ------------------------------Loss Setting------------------------------
        if self.config.turn:
            # Two alternating optimizers: the model loss updates 'emb', 'enc' and
            # 'dec' variables (and increments global_step); the discriminator loss
            # updates 'emb' and 'disc' variables with its own learning rate.
            self.loss = self.loss_recon + \
                        self.beta * tf.maximum(tf.maximum(self.loss_kl_sent_gmm, self.config.capacity_gmm) \
                                               - self.loss_kl_topic_gmm_reverse, self.config.margin_gmm) + \
                        self.beta * self.loss_kl_root + \
                        self.topic_loss_recon + \
                        self.beta * self.topic_loss_kl_bow + \
                        self.beta * self.topic_loss_kl_prob + \
                        self.config.lam_reg * self.loss_reg

            self.opt, self.grad_vars, self.clipped_grad_vars = \
                get_opt(self.loss, var_list=list(tf.trainable_variables('emb') + tf.trainable_variables('enc') + tf.trainable_variables('dec')), \
                        lr=self.config.lr, global_step=self.global_step)

            self.loss_disc = self.beta_disc * self.config.lam_disc * self.loss_disc_topic + \
                             self.beta * tf.maximum(self.loss_kl_prob, self.config.capacity_prob)

            self.opt_disc, self.grad_vars_disc, self.clipped_grad_vars_disc = \
                get_opt(self.loss_disc, var_list=list(tf.trainable_variables('emb') + tf.trainable_variables('disc')), lr=self.config.lr_disc)

        else:
            # Single joint loss over all trainable variables; loss_disc and
            # opt_disc are constant placeholders so callers can fetch them uniformly.
            self.loss = self.loss_recon + \
                        self.beta * tf.maximum(tf.maximum(self.loss_kl_sent_gmm, self.config.capacity_gmm) \
                                               - self.loss_kl_topic_gmm_reverse, self.config.margin_gmm) + \
                        self.beta * self.loss_kl_root + \
                        self.topic_loss_recon + \
                        self.beta * self.topic_loss_kl_bow + \
                        self.beta * self.topic_loss_kl_prob + \
                        self.beta_disc * self.config.lam_disc * self.loss_disc_topic + \
                        self.beta * tf.maximum(self.loss_kl_prob, self.config.capacity_prob) + \
                        self.config.lam_reg * self.loss_reg
            self.loss_disc = tf.constant(0, dtype=tf.float32)

            self.opt, self.grad_vars, self.clipped_grad_vars = \
                get_opt(self.loss, var_list=tf.trainable_variables(), lr=self.config.lr, global_step=self.global_step)
            self.opt_disc = tf.constant(0, dtype=tf.float32)

    # ------------------------------Evaluation------------------------------
    # The train list has the same order as the eval list, except that its last
    # slot is a constant 0 instead of the monitoring-only loss_kl_topic_gmm.
    self.loss_list_train = [self.loss, self.loss_disc, self.loss_recon, self.loss_kl_prob, self.loss_kl_sent_gmm, self.loss_kl_topic_gmm_reverse, \
                            self.loss_kl_root, self.loss_disc_topic, self.topic_loss_recon, self.topic_loss_kl_bow, self.topic_loss_kl_prob, self.loss_reg, tf.constant(0)]
    self.loss_list_eval = [self.loss, self.loss_disc, self.loss_recon, self.loss_kl_prob, self.loss_kl_sent_gmm, self.loss_kl_topic_gmm_reverse, \
                           self.loss_kl_root, self.loss_disc_topic, self.topic_loss_recon, self.topic_loss_kl_bow, self.topic_loss_kl_prob, self.loss_reg, self.loss_kl_topic_gmm]
    # Unweighted (no beta / lambda) sum of the main loss terms, scaled by the
    # number of valid sentences.
    self.loss_sum = (self.loss_recon + self.loss_kl_prob + self.loss_kl_sent_gmm + self.loss_kl_root + self.loss_disc_topic + \
                     self.topic_loss_recon + self.topic_loss_kl_bow + self.topic_loss_kl_prob) * self.n_sents