def discriminator_dag_supervised(latent, dag, dag_bw, dag_feats, sequence_length,
                                 params, idx, weights_regularizer=None,
                                 is_training=True):
    # latent: (N, L, D)
    with tf.variable_scope('discriminator'):
        h = tf.concat([latent, dag_feats], axis=-1)
        with tf.variable_scope("upward"):
            h = message_passing(
                latent=h,
                dag_bw=dag_bw,
                params=params,
                fully_connected_fn=sn_fully_connected,
                weights_regularizer=weights_regularizer,
                hidden_depth=params.discriminator_layers,
                dim=params.discriminator_dim)
        with tf.variable_scope("downward"):
            h = message_passing(
                latent=h,
                dag_bw=dag,
                params=params,
                fully_connected_fn=sn_fully_connected,
                weights_regularizer=weights_regularizer,
                hidden_depth=params.discriminator_layers,
                dim=params.discriminator_dim)
        with tf.variable_scope('output_mlp'):
            if params.lstm_output_discriminator:
                h, _ = lstm(
                    x=h,
                    num_units=params.decoder_dim,
                    bidirectional=True,
                    num_layers=params.decoder_layers,
                    sequence_lengths=sequence_length)
            else:
                for i in range(params.discriminator_layers):
                    h = sn_fully_connected(
                        inputs=h,
                        activation_fn=tf.nn.leaky_relu,
                        weights_regularizer=weights_regularizer,
                        num_outputs=params.discriminator_dim,
                        scope='discriminator_output_{}'.format(i))
            logits = sn_fully_connected(
                inputs=h,
                num_outputs=1,
                activation_fn=None,
                scope='discriminator_logits',
                weights_regularizer=weights_regularizer)  # (N, L, 1)
            logits = tf.squeeze(logits, axis=-1)  # (N, L)
            # Keep only the logits at valid (non-padding) positions, then
            # average them per sequence via a sparse sum.
            logits_values = tf.gather_nd(params=logits, indices=idx)  # (X,)
            sparse_logits = tf.SparseTensor(
                values=logits_values,
                indices=tf.cast(idx, tf.int64),
                dense_shape=tf.cast(tf.shape(logits), tf.int64))
            sparse_logits = tf.sparse_reorder(sparse_logits)
            dis_values = tf.sparse_reduce_sum(
                sp_input=sparse_logits, axis=-1) / tf.cast(sequence_length,
                                                           tf.float32)  # (n,)
        return dis_values  # (n,)
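# message_passing is not defined in this excerpt. The sketch below is one
# plausible reading, assuming it propagates node states along the (possibly
# soft) adjacency matrix passed as `dag_bw` and updates them with an MLP; the
# keyword arguments mirror the call sites above. It is an illustration, not
# the repository's implementation.
def message_passing_sketch(latent, dag_bw, params,
                           fully_connected_fn=slim.fully_connected,
                           weights_regularizer=None, hidden_depth=2, dim=128):
    h = latent
    for i in range(hidden_depth):
        # (N, L, L) @ (N, L, D) -> (N, L, D): each node sums its neighbors' states.
        messages = tf.matmul(dag_bw, h)
        # Update each node from its own state plus the aggregated messages.
        h = fully_connected_fn(
            inputs=tf.concat([h, messages], axis=-1),
            num_outputs=dim,
            activation_fn=tf.nn.leaky_relu,
            weights_regularizer=weights_regularizer,
            scope='message_passing_{}'.format(i))
    return h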
def discriminator_dag_supervised_tags(latent, dag, dag_bw, params, idx, tags,
                                      tag_size, weights_regularizer=None,
                                      is_training=True):
    # Tag-conditioned variant of discriminator_dag_supervised.
    # latent: (N, L, D)
    with tf.variable_scope('discriminator'):
        with tf.variable_scope('tag_embedding'):
            tag_embeddings = sn_kernel(
                shape=[tag_size, params.decoder_dim],
                scope="tag_embeddings")
            h_tags = tf.nn.embedding_lookup(params=tag_embeddings, ids=tags)
        h = tf.concat([latent, h_tags], axis=-1)
        with tf.variable_scope("forward"):
            h = message_passing(
                latent=h,
                dag_bw=dag_bw,
                params=params,
                fully_connected_fn=sn_fully_connected)
        with tf.variable_scope("backward"):
            h = message_passing(
                latent=h,
                dag_bw=dag,
                params=params,
                fully_connected_fn=sn_fully_connected)
        with tf.variable_scope('output_mlp'):
            h = sn_fully_connected(
                inputs=h,
                activation_fn=tf.nn.leaky_relu,
                num_outputs=params.decoder_dim,
                scope='output_1',
                weights_regularizer=weights_regularizer)
            logits = sn_fully_connected(
                inputs=h,
                num_outputs=1,
                activation_fn=None,
                scope='output_2',
                weights_regularizer=weights_regularizer)  # (N, L, 1)
            logits = tf.squeeze(logits, axis=-1)  # (N, L)
            logits_values = tf.gather_nd(params=logits, indices=idx)  # (X,)
            sparse_logits = tf.SparseTensor(
                values=logits_values,
                indices=tf.cast(idx, tf.int64),
                dense_shape=tf.cast(tf.shape(logits), tf.int64))
            sparse_logits = tf.sparse_reorder(sparse_logits)
            dis_values = tf.sparse_reduce_sum(
                sp_input=sparse_logits, axis=-1)  # (n,)
        return dis_values  # (n,)
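# `idx` is not built in this excerpt. Both discriminators above use it as the
# (batch, position) coordinates of non-padding tokens, which is what
# tf.gather_nd and tf.SparseTensor expect. A minimal sketch under that
# assumption (make_valid_token_idx is a hypothetical helper name):
def make_valid_token_idx(sequence_length, max_len):
    mask = tf.sequence_mask(sequence_length, maxlen=max_len)  # (N, L) bool
    return tf.where(mask)  # (X, 2) int64 coordinates of valid tokens
# With idx built this way, the sparse sum over axis -1 divided by
# sequence_length in discriminator_dag_supervised is simply the mean logit
# over each sequence's valid tokens.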
def decoder_dag_supervised(latent, dag, dag_bw, sequence_length, dag_feats,
                           vocab_size, params, weights_regularizer=None,
                           is_training=True):
    # latent: (N, L, D)
    with tf.variable_scope('decoder'):
        h = tf.concat([latent, dag_feats], axis=-1)
        with tf.variable_scope("upward"):
            h = message_passing(
                latent=h,
                dag_bw=dag_bw,
                params=params,
                dim=params.decoder_dim,
                hidden_depth=params.decoder_layers,
                weights_regularizer=weights_regularizer)
        with tf.variable_scope("downward"):
            h = message_passing(
                latent=h,
                dag_bw=dag,
                params=params,
                dim=params.decoder_dim,
                hidden_depth=params.decoder_layers,
                weights_regularizer=weights_regularizer)
        with tf.variable_scope('output_mlp'):
            if params.lstm_output:
                h, _ = lstm(
                    x=h,
                    num_units=params.decoder_dim,
                    bidirectional=True,
                    num_layers=params.decoder_layers,
                    sequence_lengths=sequence_length)
            else:
                for i in range(params.decoder_layers):
                    h = slim.fully_connected(
                        inputs=h,
                        activation_fn=tf.nn.leaky_relu,
                        weights_regularizer=weights_regularizer,
                        num_outputs=params.decoder_dim,
                        scope='decoder_output_{}'.format(i))
            logits = slim.fully_connected(
                inputs=h,
                num_outputs=vocab_size,
                activation_fn=None,
                scope='decoder_output_logits',
                weights_regularizer=weights_regularizer,
                # biases_initializer=tf.initializers.constant(
                #     value=get_bias(smoothing=params.bias_smoothing),
                #     verify_shape=True)
            )  # (N, L, V)
        return logits
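# decoder_dag_supervised returns per-token logits (N, L, V). A sketch of the
# masked cross-entropy such a decoder is typically trained with; `labels`
# (N, L) and the helper name are illustrative, not from this repository:
def masked_reconstruction_loss(logits, labels, sequence_length):
    losses = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=labels, logits=logits)  # (N, L)
    mask = tf.cast(
        tf.sequence_mask(sequence_length, maxlen=tf.shape(labels)[1]),
        tf.float32)
    # Average over valid tokens only, per sequence.
    return tf.reduce_sum(losses * mask, axis=-1) / tf.cast(sequence_length,
                                                           tf.float32)  # (N,)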
def encoder_dag(dag, dag_bw, dag_feats, text, vocab_size, params,
                weights_regularizer=None):
    with tf.variable_scope("encoder"):
        text_embeddings = tf.get_variable(
            dtype=tf.float32,
            name="text_embeddings",
            shape=[vocab_size, params.encoder_dim],
            initializer=tf.initializers.truncated_normal(
                stddev=1. / tf.sqrt(tf.constant(params.encoder_dim,
                                                dtype=tf.float32))))
        h_text = tf.nn.embedding_lookup(params=text_embeddings, ids=text)  # (N, L, D)
        h = tf.concat([h_text, dag_feats], axis=-1)
        with tf.variable_scope("upward"):
            h = message_passing(
                latent=h,
                dag_bw=dag_bw,
                params=params,
                dim=params.encoder_dim,
                hidden_depth=params.encoder_layers,
                weights_regularizer=weights_regularizer)
        with tf.variable_scope("downward"):
            h = message_passing(
                latent=h,
                dag_bw=dag,
                params=params,
                dim=params.encoder_dim,
                hidden_depth=params.encoder_layers,
                weights_regularizer=weights_regularizer)
        for i in range(params.encoder_layers):
            h = slim.fully_connected(
                inputs=h,
                activation_fn=tf.nn.leaky_relu,
                num_outputs=params.encoder_dim,
                scope='encoder_output_{}'.format(i),
                weights_regularizer=weights_regularizer)
        mu = slim.fully_connected(
            inputs=h,
            num_outputs=params.latent_dim,
            activation_fn=None,
            scope='encoder_mlp_mu',
            weights_regularizer=weights_regularizer)
        logsigma = slim.fully_connected(
            inputs=h,
            num_outputs=params.latent_dim,
            activation_fn=None,
            scope='encoder_mlp_logsigma',
            weights_regularizer=weights_regularizer)
        return mu, logsigma
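# encoder_dag returns per-token Gaussian parameters (mu, logsigma). A standard
# VAE consumes them via the reparameterization trick; this is the textbook
# form, not code from this repository:
def sample_latent(mu, logsigma):
    eps = tf.random_normal(tf.shape(mu))
    z = mu + tf.exp(logsigma) * eps  # differentiable sample from q(z|x)
    # KL(q(z|x) || N(0, I)), summed over the latent dimension: (N, L)
    kl = 0.5 * tf.reduce_sum(
        tf.square(mu) + tf.exp(2. * logsigma) - 2. * logsigma - 1., axis=-1)
    return z, kl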
def discriminator_dag(latent, dag, dag_bw, sequence_length, params,
                      weights_regularizer=None, is_training=True):
    # latent: (N, L, D)
    N = tf.shape(latent)[0]
    L = tf.shape(latent)[1]
    with tf.variable_scope('discriminator'):
        # Relative-position feature: each token's index normalized by its
        # sequence length. tf.linspace needs float endpoints and the division
        # needs float operands, so cast explicitly.
        h_linspace = tf.linspace(start=0., stop=tf.cast(L, tf.float32), num=L)
        h_linspace = tf.tile(tf.expand_dims(h_linspace, 0), [N, 1])
        h_linspace = h_linspace / tf.cast(tf.expand_dims(sequence_length, axis=1),
                                          tf.float32)
        h = tf.concat([latent, tf.expand_dims(h_linspace, -1)], axis=-1)
        with tf.variable_scope("upward"):
            h = message_passing(
                latent=h,
                dag_bw=dag_bw,
                params=params,
                dim=params.discriminator_dim,
                hidden_depth=params.discriminator_layers,
                weights_regularizer=weights_regularizer)
        with tf.variable_scope("downward"):
            h = message_passing(
                latent=h,
                dag_bw=dag,
                params=params,
                dim=params.discriminator_dim,
                hidden_depth=params.discriminator_layers,
                weights_regularizer=weights_regularizer)
        with tf.variable_scope('output_mlp'):
            for i in range(params.discriminator_layers):
                h = slim.fully_connected(
                    inputs=h,
                    activation_fn=tf.nn.leaky_relu,
                    weights_regularizer=weights_regularizer,
                    num_outputs=params.discriminator_dim,
                    scope='discriminator_output_{}'.format(i))
            logits = slim.fully_connected(
                inputs=h,
                num_outputs=1,
                activation_fn=None,
                scope='discriminator_output_logits',
                weights_regularizer=weights_regularizer,
                biases_initializer=tf.initializers.constant(
                    value=get_bias(smoothing=params.bias_smoothing),
                    verify_shape=True))  # (N, L, 1)
        return logits
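# discriminator_dag returns per-position logits (N, L, 1). A sketch of how
# they could be pooled into one critic score per example, assuming a
# WGAN-style objective (the actual training loss is not part of this excerpt):
def critic_score(logits, sequence_length):
    logits = tf.squeeze(logits, axis=-1)  # (N, L)
    mask = tf.cast(
        tf.sequence_mask(sequence_length, maxlen=tf.shape(logits)[1]),
        tf.float32)
    # Mean logit over valid positions; real-minus-fake means give the critic loss.
    return tf.reduce_sum(logits * mask, axis=-1) / tf.cast(sequence_length,
                                                           tf.float32)  # (N,)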
def encoder_dag_tags(dag, dag_bw, text, tags, vocab_size, tags_size, params,
                     weights_regularizer=None):
    # Tag-conditioned variant of encoder_dag.
    with tf.variable_scope("encoder"):
        text_embeddings = tf.get_variable(
            dtype=tf.float32,
            name="text_embeddings",
            shape=[vocab_size, params.encoder_dim],
            initializer=tf.initializers.truncated_normal(
                stddev=1. / tf.sqrt(tf.constant(params.encoder_dim,
                                                dtype=tf.float32))))
        tag_embeddings = tf.get_variable(
            dtype=tf.float32,
            name="tag_embeddings",
            shape=[tags_size, params.encoder_dim],
            initializer=tf.initializers.truncated_normal(
                stddev=1. / tf.sqrt(tf.constant(params.encoder_dim,
                                                dtype=tf.float32))))
        h_text = tf.nn.embedding_lookup(params=text_embeddings, ids=text)  # (N, L, D)
        h_tags = tf.nn.embedding_lookup(params=tag_embeddings, ids=tags)  # (N, L, D)
        h = tf.concat([h_text, h_tags], axis=-1)
        with tf.variable_scope("forward"):
            h = message_passing(latent=h, dag_bw=dag_bw, params=params)
        with tf.variable_scope("backward"):
            h = message_passing(latent=h, dag_bw=dag, params=params)
        mu = slim.fully_connected(
            inputs=h,
            num_outputs=params.latent_dim,
            activation_fn=None,
            scope='encoder_mlp_mu',
            weights_regularizer=weights_regularizer)
        logsigma = slim.fully_connected(
            inputs=h,
            num_outputs=params.latent_dim,
            activation_fn=None,
            scope='encoder_mlp_logsigma',
            weights_regularizer=weights_regularizer)
        return mu, logsigma
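# Throughout these functions, `dag` and `dag_bw` are a batched adjacency
# matrix and its transpose (vae_decoder_dag below builds dag_bw as
# tf.transpose(dag, (0, 2, 1))). A tiny concrete example for a 3-node chain
# 0 -> 1 -> 2:
import numpy as np
dag_example = np.array([[[0., 1., 0.],
                         [0., 0., 1.],
                         [0., 0., 0.]]], dtype=np.float32)  # (1, 3, 3)
dag_bw_example = np.transpose(dag_example, (0, 2, 1))  # edges reversed: 2 -> 1 -> 0
# Feeding dag vs. dag_bw into message_passing is what produces the paired
# "upward"/"downward" (or "forward"/"backward") passes used above.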
def vae_decoder_dag_supervised(latent, dag, dag_bw, vocab_size, params, tags,
                               tag_size, weights_regularizer=None,
                               is_training=True):
    # latent: (N, L, D)
    with tf.variable_scope('decoder'):
        tag_embeddings = tf.get_variable(
            dtype=tf.float32,
            name="tag_embeddings",
            shape=[tag_size, params.decoder_dim],
            initializer=tf.initializers.truncated_normal(
                # Scale by the embedding's own dimension (decoder_dim).
                stddev=1. / tf.sqrt(tf.constant(params.decoder_dim,
                                                dtype=tf.float32))))
        h_tags = tf.nn.embedding_lookup(params=tag_embeddings, ids=tags)
        h = tf.concat([latent, h_tags], axis=-1)
        with tf.variable_scope("forward"):
            h = message_passing(latent=h, dag_bw=dag_bw, params=params)
        with tf.variable_scope("backward"):
            h = message_passing(latent=h, dag_bw=dag, params=params)
        with tf.variable_scope('output_mlp'):
            h = slim.fully_connected(
                inputs=h,
                activation_fn=tf.nn.leaky_relu,
                num_outputs=params.decoder_dim,
                scope='output_1')
            logits = slim.fully_connected(
                inputs=h,
                num_outputs=vocab_size,
                activation_fn=None,
                scope='output_2',
                weights_regularizer=weights_regularizer,
                biases_initializer=tf.initializers.constant(
                    value=get_bias(smoothing=params.bias_smoothing),
                    verify_shape=True))  # (N, L, V)
        return logits
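# get_bias is not defined in this excerpt. Given how it is used (a constant
# initializer for the vocab-sized output bias, controlled by a `smoothing`
# parameter), one plausible reading is a smoothed unigram log-prior over the
# vocabulary. A hypothetical sketch; the name, signature, and `token_counts`
# input are assumptions:
import numpy as np
def get_bias_sketch(token_counts, smoothing):
    # token_counts: (V,) corpus frequencies; add-k smoothing, then log-probability.
    probs = (token_counts + smoothing) / np.sum(token_counts + smoothing)
    return np.log(probs).astype(np.float32)  # (V,) initial output-layer biases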
def vae_decoder_dag(latent, sequence_lengths, vocab_size, params, n,
                    weights_regularizer=None, is_training=True):
    # latent: (N, L, D)
    with tf.variable_scope('decoder'):
        latent_processed = process_latent(
            latent=latent, sequence_lengths=sequence_lengths, params=params)
        dag, penalty = make_dag(
            latent=latent_processed, sequence_lengths=sequence_lengths,
            params=params)
        dag_bw = tf.transpose(dag, (0, 2, 1))
        hidden = message_passing(latent=latent, dag_bw=dag, params=params)
        hidden = message_passing(latent=hidden, dag_bw=dag_bw, params=params)
        # hidden: (N, L, Dlatent + decoder_dim)
        h = slim.fully_connected(
            inputs=hidden,
            activation_fn=tf.nn.leaky_relu,
            num_outputs=params.decoder_dim,
            scope='output_1')
        logits = slim.fully_connected(
            inputs=h,
            num_outputs=vocab_size,
            activation_fn=None,
            scope='output_2',
            weights_regularizer=weights_regularizer,
            biases_initializer=tf.initializers.constant(
                value=get_bias(smoothing=params.bias_smoothing),
                verify_shape=True))  # (N, L, V)
        return logits, penalty
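# A sketch of how these pieces might compose into a single VAE objective,
# reusing the hypothetical helpers sketched earlier (sample_latent,
# masked_reconstruction_loss); the wiring is illustrative, not from this
# repository:
def vae_loss_sketch(dag, dag_bw, dag_feats, text, sequence_length,
                    vocab_size, params, n):
    mu, logsigma = encoder_dag(dag=dag, dag_bw=dag_bw, dag_feats=dag_feats,
                               text=text, vocab_size=vocab_size, params=params)
    z, kl = sample_latent(mu, logsigma)  # z: (N, L, latent_dim), kl: (N, L)
    logits, penalty = vae_decoder_dag(latent=z, sequence_lengths=sequence_length,
                                      vocab_size=vocab_size, params=params, n=n)
    recon = masked_reconstruction_loss(logits, text, sequence_length)  # (N,)
    kl_per_seq = tf.reduce_sum(kl, axis=-1)  # (N,); padding KL not masked here
    return tf.reduce_mean(recon + kl_per_seq) + penalty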