Example 1
def discriminator_dag_supervised(latent,
                                 dag,
                                 dag_bw,
                                 dag_feats,
                                 sequence_length,
                                 params,
                                 idx,
                                 weights_regularizer=None,
                                 is_training=True):
    # latent (N, L, D)
    with tf.variable_scope('discriminator'):
        h = tf.concat([latent, dag_feats], axis=-1)
        with tf.variable_scope("upward"):
            h = message_passing(latent=h,
                                dag_bw=dag_bw,
                                params=params,
                                fully_connected_fn=sn_fully_connected,
                                weights_regularizer=weights_regularizer,
                                hidden_depth=params.discriminator_layers,
                                dim=params.discriminator_dim)
        with tf.variable_scope("downward"):
            h = message_passing(latent=h,
                                dag_bw=dag,
                                params=params,
                                fully_connected_fn=sn_fully_connected,
                                weights_regularizer=weights_regularizer,
                                hidden_depth=params.discriminator_layers,
                                dim=params.discriminator_dim)
        with tf.variable_scope('output_mlp'):
            if params.lstm_output_discriminator:
                h, _ = lstm(x=h,
                            num_units=params.decoder_dim,
                            bidirectional=True,
                            num_layers=params.decoder_layers,
                            sequence_lengths=sequence_length)
            else:
                for i in range(params.discriminator_layers):
                    h = sn_fully_connected(
                        inputs=h,
                        activation_fn=tf.nn.leaky_relu,
                        weights_regularizer=weights_regularizer,
                        num_outputs=params.discriminator_dim,
                        scope='discriminator_output_{}'.format(i))
            logits = sn_fully_connected(
                inputs=h,
                num_outputs=1,
                activation_fn=None,
                scope='discriminator_logits',
                weights_regularizer=weights_regularizer)  # (N,L,1)
            logits = tf.squeeze(logits, axis=-1)  # (N, L)
            # idx lists the valid (batch, node) coordinates; gather their logits
            logits_values = tf.gather_nd(params=logits, indices=idx)  # (X,)
            sparse_logits = tf.SparseTensor(values=logits_values,
                                            indices=tf.cast(idx, tf.int64),
                                            dense_shape=tf.cast(
                                                tf.shape(logits), tf.int64))
            sparse_logits = tf.sparse_reorder(sparse_logits)
            dis_values = tf.sparse_reduce_sum(
                sp_input=sparse_logits, axis=-1) / tf.cast(
                    sequence_length, tf.float32)  # mean node logit per graph (n,)
        return dis_values  # (n,)
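Every example here relies on a shared `message_passing` helper whose definition is not shown. The sketch below is a reconstruction from the call sites alone, so treat every detail as an assumption: it reads `dag_bw` as a batched (N, L, L) float adjacency matrix, so a single matmul per layer aggregates each node's neighbour states along the DAG edges before an MLP mixes them back into the node state.

import tensorflow as tf
import tensorflow.contrib.slim as slim

def message_passing(latent, dag_bw, params,
                    fully_connected_fn=slim.fully_connected,
                    weights_regularizer=None,
                    hidden_depth=2, dim=256):
    # Hypothetical reconstruction; `params` is accepted only for parity with
    # the call sites, which presumably use it to default hidden_depth/dim.
    h = latent
    for i in range(hidden_depth):
        # (N, L, L) x (N, L, D) -> (N, L, D): sum the states of each node's
        # neighbours along the DAG edges.
        messages = tf.matmul(dag_bw, h)
        # Mix each node's own state with the aggregated messages.
        h = fully_connected_fn(
            inputs=tf.concat([h, messages], axis=-1),
            num_outputs=dim,
            activation_fn=tf.nn.leaky_relu,
            weights_regularizer=weights_regularizer,
            scope='message_passing_{}'.format(i))
    return h

The upward/downward (or forward/backward) scope pairs in the examples then amount to running this once over the transposed adjacency and once over the original, so information flows through the DAG in both directions.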
Example 2
def discriminator_dag_supervised(
        latent, dag, dag_bw, params, idx,
        tags, tag_size, weights_regularizer=None,
        is_training=True):
    # latent (N, L, D)
    with tf.variable_scope('decoder'):
        with tf.variable_scope('tag_embedding'):
            tag_embeddings = sn_kernel(
                shape=[tag_size, params.decoder_dim],
                scope="tag_embeddings"
            )
        h_tags = tf.nn.embedding_lookup(params=tag_embeddings, ids=tags)
        h = tf.concat([latent, h_tags], axis=-1)
        with tf.variable_scope("forward"):
            h = message_passing(
                latent=h,
                dag_bw=dag_bw,
                params=params,
                fully_connected_fn=sn_fully_connected
            )
        with tf.variable_scope("backward"):
            h = message_passing(
                latent=h,
                dag_bw=dag,
                params=params,
                fully_connected_fn=sn_fully_connected
            )
        with tf.variable_scope('output_mlp'):
            h = sn_fully_connected(
                inputs=h,
                activation_fn=tf.nn.leaky_relu,
                num_outputs=params.decoder_dim,
                scope='output_1',
                weights_regularizer=weights_regularizer
            )
            logits = sn_fully_connected(
                inputs=h,
                num_outputs=1,
                activation_fn=None,
                scope='output_2',
                weights_regularizer=weights_regularizer
            )  # (N,L,1)
            logits = tf.squeeze(logits, axis=-1)  # (N, L)
            logits_values = tf.gather_nd(
                params=logits,
                indices=idx
            )  # (X,)
            sparse_logits = tf.SparseTensor(
                values=logits_values,
                indices=tf.cast(idx, tf.int64),
                dense_shape=tf.cast(tf.shape(logits), tf.int64)
            )
            sparse_logits = tf.sparse_reorder(sparse_logits)
            dis_values = tf.sparse_reduce_sum(
                sp_input=sparse_logits,
                axis=-1
            )  # (n,)
        return dis_values  # (n,)
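`sn_kernel` and `sn_fully_connected` are likewise project-local. The `sn` prefix and their use inside GAN discriminators suggest spectral normalization (Miyato et al., 2018); the following is a minimal sketch under that assumption, with a single power-iteration step estimating the kernel's largest singular value.

import tensorflow as tf

def sn_kernel(shape, scope, eps=1e-12):
    # Weight matrix divided by an estimate of its largest singular value.
    with tf.variable_scope(scope):
        w = tf.get_variable('kernel', shape=shape, dtype=tf.float32)
        w_mat = tf.reshape(w, [-1, shape[-1]])  # (d_in, d_out)
        u = tf.get_variable('u', shape=[1, shape[-1]], trainable=False,
                            initializer=tf.initializers.truncated_normal())
        # One step of power iteration, persisted across steps via `u`.
        v = tf.nn.l2_normalize(tf.matmul(u, w_mat, transpose_b=True),
                               epsilon=eps)
        u_new = tf.nn.l2_normalize(tf.matmul(v, w_mat), epsilon=eps)
        sigma = tf.squeeze(tf.matmul(tf.matmul(v, w_mat), u_new,
                                     transpose_b=True))
        with tf.control_dependencies([tf.assign(u, u_new)]):
            return w / sigma

def sn_fully_connected(inputs, num_outputs, activation_fn=None,
                       weights_regularizer=None, scope=None):
    # Dense layer over the last axis with a spectrally normalized kernel.
    with tf.variable_scope(scope, default_name='sn_fully_connected'):
        w = sn_kernel(shape=[inputs.shape[-1].value, num_outputs], scope='sn')
        b = tf.get_variable('bias', shape=[num_outputs],
                            initializer=tf.zeros_initializer())
        if weights_regularizer is not None:
            tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES,
                                 weights_regularizer(w))
        h = tf.tensordot(inputs, w, axes=1) + b
        return activation_fn(h) if activation_fn is not None else h

Note that only the discriminators use the sn_* layers; the encoders and decoders stick to plain slim.fully_connected, which matches the usual practice of constraining only the discriminator's Lipschitz constant.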
Example 3
def decoder_dag_supervised(latent,
                           dag,
                           dag_bw,
                           sequence_length,
                           dag_feats,
                           vocab_size,
                           params,
                           weights_regularizer=None,
                           is_training=True):
    # latent (N, L, D)
    with tf.variable_scope('decoder'):
        h = tf.concat([latent, dag_feats], axis=-1)
        with tf.variable_scope("upward"):
            h = message_passing(latent=h,
                                dag_bw=dag_bw,
                                params=params,
                                dim=params.decoder_dim,
                                hidden_depth=params.decoder_layers,
                                weights_regularizer=weights_regularizer)
        with tf.variable_scope("downward"):
            h = message_passing(latent=h,
                                dag_bw=dag,
                                params=params,
                                dim=params.decoder_dim,
                                hidden_depth=params.decoder_layers,
                                weights_regularizer=weights_regularizer)
        with tf.variable_scope('output_mlp'):
            if params.lstm_output:
                h, _ = lstm(x=h,
                            num_units=params.decoder_dim,
                            bidirectional=True,
                            num_layers=params.decoder_layers,
                            sequence_lengths=sequence_length)
            else:
                for i in range(params.decoder_layers):
                    h = slim.fully_connected(
                        inputs=h,
                        activation_fn=tf.nn.leaky_relu,
                        weights_regularizer=weights_regularizer,
                        num_outputs=params.decoder_dim,
                        scope='decoder_output_{}'.format(i))
            logits = slim.fully_connected(
                inputs=h,
                num_outputs=vocab_size,
                activation_fn=None,
                scope='decoder_output_logits',
                weights_regularizer=weights_regularizer,
                #biases_initializer=tf.initializers.constant(
                #    value=get_bias(smoothing=params.bias_smoothing),
                #    verify_shape=True)
            )  # (N,L,V)
        return logits
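The decoder emits per-token logits of shape (N, L, V). None of the examples show the loss that consumes them; a conventional choice, sketched here purely as an assumption, is softmax cross-entropy masked to each graph's real length:

import tensorflow as tf

def decoder_loss(logits, targets, sequence_length):
    # targets: (N, L) int token ids; positions past sequence_length are pad.
    mask = tf.sequence_mask(sequence_length, maxlen=tf.shape(logits)[1],
                            dtype=tf.float32)  # (N, L)
    xent = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=targets, logits=logits)  # (N, L)
    # Average over real tokens only, so padding does not dilute the loss.
    return tf.reduce_sum(xent * mask) / tf.reduce_sum(mask)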
Example 4
def encoder_dag(dag,
                dag_bw,
                dag_feats,
                text,
                vocab_size,
                params,
                weights_regularizer=None):
    with tf.variable_scope("encoder"):
        text_embeddings = tf.get_variable(
            dtype=tf.float32,
            name="text_embeddings",
            shape=[vocab_size, params.encoder_dim],
            initializer=tf.initializers.truncated_normal(
                stddev=1. /
                tf.sqrt(tf.constant(params.encoder_dim, dtype=tf.float32))))
        h_text = tf.nn.embedding_lookup(params=text_embeddings,
                                        ids=text)  # (N, L, D)
        h = tf.concat([h_text, dag_feats], axis=-1)
        with tf.variable_scope("upward"):
            h = message_passing(latent=h,
                                dag_bw=dag_bw,
                                params=params,
                                dim=params.encoder_dim,
                                hidden_depth=params.encoder_layers,
                                weights_regularizer=weights_regularizer)
        with tf.variable_scope("downward"):
            h = message_passing(latent=h,
                                dag_bw=dag,
                                params=params,
                                dim=params.encoder_dim,
                                hidden_depth=params.encoder_layers,
                                weights_regularizer=weights_regularizer)

        for i in range(params.encoder_layers):
            h = slim.fully_connected(inputs=h,
                                     activation_fn=tf.nn.leaky_relu,
                                     num_outputs=params.encoder_dim,
                                     scope='encoder_output_{}'.format(i),
                                     weights_regularizer=weights_regularizer)

        mu = slim.fully_connected(inputs=h,
                                  num_outputs=params.latent_dim,
                                  activation_fn=None,
                                  scope='encoder_mlp_mu',
                                  weights_regularizer=weights_regularizer)
        logsigma = slim.fully_connected(
            inputs=h,
            num_outputs=params.latent_dim,
            activation_fn=None,
            scope='encoder_mlp_logsigma',
            weights_regularizer=weights_regularizer)
        return mu, logsigma
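encoder_dag returns a diagonal-Gaussian posterior (mu, logsigma) per node. In the VAE variants below (vae_decoder_dag_supervised, vae_decoder_dag) the latent fed to the decoder would be drawn with the standard reparameterization trick; a minimal sketch:

import tensorflow as tf

def sample_latent(mu, logsigma):
    # z = mu + sigma * eps with eps ~ N(0, I); differentiable in mu, logsigma.
    eps = tf.random_normal(shape=tf.shape(mu))
    return mu + tf.exp(logsigma) * eps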
Example 5
def discriminator_dag(latent, dag, dag_bw, sequence_length, params, weights_regularizer=None,
                      is_training=True):
    # latent (N, L, D)
    N = tf.shape(latent)[0]
    L = tf.shape(latent)[1]
    with tf.variable_scope('discriminator'):
        # Relative node position (index / sequence length) as a float feature
        h_linspace = tf.linspace(start=0., stop=tf.cast(L, tf.float32), num=L)
        h_linspace = tf.tile(tf.expand_dims(h_linspace, 0), [N, 1])
        h_linspace = h_linspace / tf.cast(
            tf.expand_dims(sequence_length, axis=1), tf.float32)
        h = tf.concat([latent, tf.expand_dims(h_linspace, -1)], axis=-1)
        with tf.variable_scope("upward"):
            h = message_passing(
                latent=h,
                dag_bw=dag_bw,
                params=params,
                dim=params.discriminator_dim,
                hidden_depth=params.discriminator_layers,
                weights_regularizer=weights_regularizer
            )
        with tf.variable_scope("downward"):
            h = message_passing(
                latent=h,
                dag_bw=dag,
                params=params,
                dim=params.discriminator_dim,
                hidden_depth=params.discriminator_layers,
                weights_regularizer=weights_regularizer
            )
        with tf.variable_scope('output_mlp'):
            for i in range(params.discriminator_layers):
                h = slim.fully_connected(
                    inputs=h,
                    activation_fn=tf.nn.leaky_relu,
                    weights_regularizer=weights_regularizer,
                    num_outputs=params.discriminator_dim,
                    scope='discriminator_output_{}'.format(i)
                )
            logits = slim.fully_connected(
                inputs=h,
                num_outputs=1,
                activation_fn=None,
                scope='discriminator_output_logits',
                weights_regularizer=weights_regularizer)  # (N,L,1)
        return logits
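The examples never show the GAN objective that consumes the discriminator outputs. Spectrally normalized discriminators are most often paired with the hinge loss, so here is that pairing as a hedged sketch; the real code may instead mask and average the per-node logits by sequence length, as example 1 does.

import tensorflow as tf

def hinge_gan_losses(real_logits, fake_logits):
    # Hinge objective: push real scores above +1 and fake scores below -1.
    d_loss = (tf.reduce_mean(tf.nn.relu(1. - real_logits)) +
              tf.reduce_mean(tf.nn.relu(1. + fake_logits)))
    # The generator simply maximizes the discriminator's score on fakes.
    g_loss = -tf.reduce_mean(fake_logits)
    return d_loss, g_loss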
Example 6
def encoder_dag(dag,
                dag_bw,
                text,
                tags,
                vocab_size,
                tags_size,
                params,
                weights_regularizer=None):
    with tf.variable_scope("encoder"):
        text_embeddings = tf.get_variable(
            dtype=tf.float32,
            name="text_embeddings",
            shape=[vocab_size, params.encoder_dim],
            initializer=tf.initializers.truncated_normal(
                stddev=1. /
                tf.sqrt(tf.constant(params.encoder_dim, dtype=tf.float32))))
        tag_embeddings = tf.get_variable(
            dtype=tf.float32,
            name="tag_embeddings",
            shape=[tags_size, params.encoder_dim],
            initializer=tf.initializers.truncated_normal(
                stddev=1. /
                tf.sqrt(tf.constant(params.encoder_dim, dtype=tf.float32))))
        h_text = tf.nn.embedding_lookup(params=text_embeddings,
                                        ids=text)  # (N, L, D)
        h_tags = tf.nn.embedding_lookup(params=tag_embeddings,
                                        ids=tags)  # (N, L, D)
        h = tf.concat([h_text, h_tags], axis=-1)
        with tf.variable_scope("forward"):
            h = message_passing(latent=h, dag_bw=dag_bw, params=params)
        with tf.variable_scope("backward"):
            h = message_passing(latent=h, dag_bw=dag, params=params)

        mu = slim.fully_connected(inputs=h,
                                  num_outputs=params.latent_dim,
                                  activation_fn=None,
                                  scope='encoder_mlp_mu',
                                  weights_regularizer=weights_regularizer)
        logsigma = slim.fully_connected(
            inputs=h,
            num_outputs=params.latent_dim,
            activation_fn=None,
            scope='encoder_mlp_logsigma',
            weights_regularizer=weights_regularizer)
        return mu, logsigma
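The same (mu, logsigma) pair also determines the KL term of the VAE objective. Against a standard normal prior the closed form, summed over latent dimensions per node, is 0.5 * sum(mu^2 + sigma^2 - 2*logsigma - 1):

import tensorflow as tf

def kl_to_standard_normal(mu, logsigma):
    # KL(N(mu, sigma^2) || N(0, I)) with sigma = exp(logsigma); shape (N, L).
    return 0.5 * tf.reduce_sum(
        tf.square(mu) + tf.exp(2. * logsigma) - 2. * logsigma - 1.,
        axis=-1)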
Example 7
def vae_decoder_dag_supervised(latent,
                               dag,
                               dag_bw,
                               vocab_size,
                               params,
                               tags,
                               tag_size,
                               weights_regularizer=None,
                               is_training=True):
    # latent (N, L, D)
    with tf.variable_scope('decoder'):
        tag_embeddings = tf.get_variable(
            dtype=tf.float32,
            name="tag_embeddings",
            shape=[tag_size, params.decoder_dim],
            initializer=tf.initializers.truncated_normal(
                stddev=1. /
                tf.sqrt(tf.constant(params.decoder_dim, dtype=tf.float32))))
        h_tags = tf.nn.embedding_lookup(params=tag_embeddings, ids=tags)
        h = tf.concat([latent, h_tags], axis=-1)
        with tf.variable_scope("forward"):
            h = message_passing(latent=h, dag_bw=dag_bw, params=params)
        with tf.variable_scope("backward"):
            h = message_passing(latent=h, dag_bw=dag, params=params)
        with tf.variable_scope('output_mlp'):
            h = slim.fully_connected(inputs=h,
                                     activation_fn=tf.nn.leaky_relu,
                                     num_outputs=params.decoder_dim,
                                     scope='output_1')
            logits = slim.fully_connected(
                inputs=h,
                num_outputs=vocab_size,
                activation_fn=None,
                scope='output_2',
                weights_regularizer=weights_regularizer,
                biases_initializer=tf.initializers.constant(
                    value=get_bias(smoothing=params.bias_smoothing),
                    verify_shape=True))  # (N,L,V)
        return logits
Example 8
def vae_decoder_dag(latent,
                    sequence_lengths,
                    vocab_size,
                    params,
                    n,
                    weights_regularizer=None,
                    is_training=True):
    # latent (N, L, D)
    with tf.variable_scope('decoder'):
        latent_processed = process_latent(latent=latent,
                                          sequence_lengths=sequence_lengths,
                                          params=params)

        dag, penalty = make_dag(latent=latent_processed,
                                sequence_lengths=sequence_lengths,
                                params=params)

        dag_bw = tf.transpose(dag, (0, 2, 1))
        hidden = message_passing(latent=latent, dag_bw=dag, params=params)
        hidden = message_passing(latent=hidden, dag_bw=dag_bw, params=params)
        # hidden (N, L, Dlatent+decoder_dim)

        h = slim.fully_connected(inputs=hidden,
                                 activation_fn=tf.nn.leaky_relu,
                                 num_outputs=params.decoder_dim,
                                 scope='output_1')
        logits = slim.fully_connected(
            inputs=h,
            num_outputs=vocab_size,
            activation_fn=None,
            scope='output_2',
            weights_regularizer=weights_regularizer,
            biases_initializer=tf.initializers.constant(
                value=get_bias(smoothing=params.bias_smoothing),
                verify_shape=True))  # (N,L,V)
        return logits, penalty
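Unlike the other decoders, vae_decoder_dag builds the DAG itself via make_dag (not shown) and returns an extra penalty, presumably a regularizer on the learned adjacency. A hypothetical end-to-end objective combining the sketches above; the weight names kl_weight and penalty_weight are assumptions, not taken from the source:

import tensorflow as tf

def training_loss(mu, logsigma, targets, sequence_lengths, vocab_size,
                  params, n, kl_weight=1.0, penalty_weight=1.0):
    # Hypothetical wiring only: reparameterized sample -> decoder ->
    # masked reconstruction loss, plus KL and DAG penalty terms.
    z = sample_latent(mu, logsigma)
    logits, penalty = vae_decoder_dag(z, sequence_lengths, vocab_size,
                                      params, n)
    recon = decoder_loss(logits, targets, sequence_lengths)
    kl = tf.reduce_mean(kl_to_standard_normal(mu, logsigma))
    return recon + kl_weight * kl + penalty_weight * tf.reduce_mean(penalty)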