Пример #1
0
def vae_flat_decoder(latent, vocab_size, params, n, weights_regularizer=None):
    """Decode a latent code into a fixed-length sequence of per-token outputs.

    The latent code is linearly projected, repeated ``params.flat_length``
    times along a new leading axis, run through a 3-layer bidirectional
    ``lstm`` and finally through three dense layers.

    :param latent: latent code; the original comment gives its shape as (N, D)
    :param vocab_size: the final dense layer emits ``vocab_size + 1`` values
    :param params: reads ``tree_depth``, ``decoder_dim``, ``flat_length`` and
        ``encoder_dim``
    :param n: not used by this function
    :param weights_regularizer: forwarded to every ``slim.fully_connected``
    :return: output of the final (linear) dense layer, (L, N, V+1) according
        to the original shape comment
    """
    with tf.variable_scope('decoder'):
        assert params.tree_depth >= 0

        projected = slim.fully_connected(
            latent,
            num_outputs=params.decoder_dim,
            scope='projection',
            activation_fn=None,
            weights_regularizer=weights_regularizer)

        # Repeat the projected code once per output position: (1,N,D) -> (L,N,D).
        repeated = tf.tile(tf.expand_dims(projected, axis=0),
                           (params.flat_length, 1, 1))

        recurrent_out, _ = lstm(x=repeated,
                                num_layers=3,
                                num_units=params.decoder_dim,
                                bidirectional=True)

        # The two hidden dense layers differ only in their variable scope.
        hidden_kwargs = dict(num_outputs=params.encoder_dim,
                             activation_fn=tf.nn.leaky_relu,
                             weights_regularizer=weights_regularizer)
        hidden = slim.fully_connected(inputs=recurrent_out,
                                      scope='decoder_mlp_1',
                                      **hidden_kwargs)
        hidden = slim.fully_connected(inputs=hidden,
                                      scope='decoder_mlp_2',
                                      **hidden_kwargs)

        return slim.fully_connected(
            inputs=hidden,
            num_outputs=vocab_size + 1,
            activation_fn=None,
            scope='decoder_mlp_3',
            weights_regularizer=weights_regularizer)  # (L,N,V+1)
Пример #2
0
def encoder_flat(tokens, token_lengths, vocab_size, params, n, weights_regularizer=None, is_training=True):
    """Encode a token batch into ``mu`` / ``logsigma`` vectors.

    Tokens are embedded, run through a bidirectional ``lstm``; the second
    value returned by ``lstm`` is concatenated, flattened per example and fed
    through two dense layers before the two linear output heads.

    :param tokens: integer token ids, transposed to (L, N) before the lookup
    :param token_lengths: passed to ``lstm`` as ``sequence_lengths``
    :param vocab_size: number of rows in the embedding table
    :param params: reads ``encoder_dim``, ``encoder_layers``, ``batch_norm``
        and ``latent_dim``
    :param n: batch size used as the leading dimension of the flattening reshape
    :param weights_regularizer: forwarded to every ``slim.fully_connected``
    :param is_training: forwarded to ``slim.batch_norm``
    :return: ``(mu, logsigma)``, each with ``params.latent_dim`` outputs
    """

    def _maybe_batch_norm(x):
        # Batch norm is optional and controlled by params.batch_norm.
        if params.batch_norm:
            return slim.batch_norm(x, is_training=is_training)
        return x

    with tf.variable_scope('encoder'):
        token_ids = tf.transpose(tokens, (1, 0))  # (L,N)
        init_stddev = 1. / tf.sqrt(
            tf.constant(params.encoder_dim, dtype=tf.float32))
        embeddings = tf.get_variable(
            name="embeddings",
            shape=[vocab_size, params.encoder_dim],
            dtype=tf.float32,
            initializer=tf.initializers.truncated_normal(stddev=init_stddev))
        embedded = tf.nn.embedding_lookup(embeddings, token_ids)  # (L, N, D)

        _, h = lstm(x=embedded,
                    num_layers=params.encoder_layers,
                    num_units=params.encoder_dim,
                    bidirectional=True,
                    sequence_lengths=token_lengths)
        print("h1: {}".format(h))
        h = tf.concat(h, axis=-1)
        print("h2: {}".format(h))
        h = tf.transpose(h, (1, 0, 2))  # batch axis first
        print("h3: {}".format(h))
        # Collapse the two trailing (statically known) axes into one.
        flat_width = h.shape[1].value * h.shape[2].value
        h = tf.reshape(h, (n, flat_width))
        print("h4: {}".format(h))

        h = _maybe_batch_norm(h)
        h = slim.fully_connected(inputs=h,
                                 num_outputs=params.encoder_dim,
                                 activation_fn=tf.nn.leaky_relu,
                                 scope='encoder_mlp_1',
                                 weights_regularizer=weights_regularizer)
        h = _maybe_batch_norm(h)
        h = slim.fully_connected(inputs=h,
                                 num_outputs=params.encoder_dim,
                                 activation_fn=tf.nn.leaky_relu,
                                 scope='encoder_mlp_2',
                                 weights_regularizer=weights_regularizer)
        h = _maybe_batch_norm(h)

        # Two linear heads that differ only in their variable scope.
        head_kwargs = dict(inputs=h,
                           num_outputs=params.latent_dim,
                           activation_fn=None,
                           weights_regularizer=weights_regularizer)
        mu = slim.fully_connected(scope='encoder_mlp_mu', **head_kwargs)
        logsigma = slim.fully_connected(scope='encoder_mlp_logsigma',
                                        **head_kwargs)
        return mu, logsigma
Пример #3
0
def process_latent(latent, sequence_lengths, params):
    """Run a 2-layer bidirectional ``lstm`` over a latent sequence.

    The first two axes of ``latent`` are swapped before the ``lstm`` call and
    swapped back afterwards, so the result has the same axis order as the input.

    :param latent: rank-3 tensor (presumably batch-major (N, L, D) — confirm
        against callers)
    :param sequence_lengths: forwarded to ``lstm`` as ``sequence_lengths``
    :param params: reads ``decoder_dim``
    :return: first output of ``lstm`` with the leading two axes swapped back
    """
    swapped = tf.transpose(latent, (1, 0, 2))
    outputs, _ = lstm(x=swapped,
                      num_layers=2,
                      num_units=params.decoder_dim,
                      bidirectional=True,
                      sequence_lengths=sequence_lengths)
    return tf.transpose(outputs, (1, 0, 2))
Пример #4
0
def vae_flat_encoder_simple(tokens,
                            token_lengths,
                            vocab_size,
                            params,
                            n,
                            weights_regularizer=None):
    """Encode tokens into per-position ``mu`` / ``logsigma`` outputs.

    Each embedded token is concatenated with a scalar position feature (a
    0..1 ramp multiplied by the padded length and divided by the example's
    token length), run through a 3-layer bidirectional ``lstm``, and projected
    by two linear heads.

    :param tokens: (N,L)
    :param token_lengths: (N,)
    :param vocab_size: number of rows in the embedding table
    :param params: reads ``encoder_dim`` and ``latent_dim``
    :param n: batch size used to tile the position feature
    :param weights_regularizer: forwarded to both output heads
    :return: ``(mu, logsigma)`` computed from the first output of ``lstm``
    """
    max_len = tf.shape(tokens)[1]
    with tf.variable_scope('encoder'):
        init_stddev = 1. / tf.sqrt(
            tf.constant(params.encoder_dim, dtype=tf.float32))
        embeddings = tf.get_variable(
            name="embeddings",
            shape=[vocab_size, params.encoder_dim],
            dtype=tf.float32,
            initializer=tf.initializers.truncated_normal(stddev=init_stddev))
        time_major_ids = tf.transpose(tokens, (1, 0))
        embedded_tokens = tf.nn.embedding_lookup(
            params=embeddings, ids=time_major_ids)  # (L, N, D)

        # Position ramp over the padded length, one copy per example.
        ramp = tf.linspace(start=tf.constant(0, dtype=tf.float32),
                           stop=tf.constant(1, dtype=tf.float32),
                           num=max_len)  # (L,)
        position = tf.tile(tf.expand_dims(ramp, 1), [1, n])  # (L,N)
        # Rescale by padded length / actual length of each example.
        position = position * tf.cast(max_len, dtype=tf.float32)
        position = position / tf.cast(tf.expand_dims(token_lengths, 0),
                                      dtype=tf.float32)
        position = tf.expand_dims(position, 2)  # (L,N,1)

        features = tf.concat([embedded_tokens, position], axis=-1)
        h, _ = lstm(x=features,
                    num_layers=3,
                    num_units=params.encoder_dim,
                    bidirectional=True,
                    sequence_lengths=token_lengths)

        # Two linear heads that differ only in their variable scope.
        head_kwargs = dict(inputs=h,
                           num_outputs=params.latent_dim,
                           activation_fn=None,
                           weights_regularizer=weights_regularizer)
        mu = slim.fully_connected(scope='encoder_mlp_mu', **head_kwargs)
        logsigma = slim.fully_connected(scope='encoder_mlp_logsigma',
                                        **head_kwargs)
        return mu, logsigma
Пример #5
0
def vae_flat_decoder_attn(latent,
                          vocab_size,
                          params,
                          n,
                          weights_regularizer=None,
                          is_training=True):
    """Decode a latent tensor into ``vocab_size + 1`` outputs per position.

    Pipeline: batch norm -> 3-layer bidirectional ``lstm`` -> batch norm ->
    one linear layer whose bias is initialised from ``get_bias_ctc``.

    :param latent: latent input (the original comment says (N, D); it is fed
        to ``lstm`` unchanged — confirm the expected rank against callers)
    :param vocab_size: the output layer emits ``vocab_size + 1`` values
    :param params: reads ``decoder_dim``, ``flat_length`` and ``bias_smoothing``
    :param n: not used by this function
    :param weights_regularizer: forwarded to the output layer
    :param is_training: forwarded to both ``slim.batch_norm`` calls
    :return: output of the linear layer, (L,N,V+1) per the original comment
    """
    with tf.variable_scope('decoder'):
        # NOTE: an input 'projection' dense layer used to sit here; it is
        # currently disabled and the latent is consumed directly.
        normed_in = slim.batch_norm(latent, is_training=is_training)
        recurrent_out, _ = lstm(x=normed_in,
                                num_layers=3,
                                num_units=params.decoder_dim,
                                bidirectional=True)
        normed_out = slim.batch_norm(recurrent_out, is_training=is_training)
        # NOTE: hidden layers 'decoder_mlp_1' / 'decoder_mlp_2' are currently
        # disabled; only the output layer remains.

        initial_bias = get_bias_ctc(average_output_length=params.flat_length,
                                    smoothing=params.bias_smoothing)
        bias_init = tf.initializers.constant(value=initial_bias,
                                             verify_shape=True)
        return slim.fully_connected(
            inputs=normed_out,
            num_outputs=vocab_size + 1,
            activation_fn=None,
            scope='decoder_mlp_3',
            weights_regularizer=weights_regularizer,
            biases_initializer=bias_init)  # (L,N,V+1)
Пример #6
0
def decoder_flat(latent, vocab_size, params, n, weights_regularizer=None):
    """Decode a latent code into ``params.flat_length`` output positions.

    The latent code is repeated along a new leading axis, concatenated with
    ``linspace_feature``, run through a bidirectional ``lstm`` and projected
    by a single linear layer.

    :param latent: latent code, (N, D) per the original comment
    :param vocab_size: the output layer emits ``vocab_size + 1`` values
    :param params: reads ``flat_length``, ``decoder_layers`` and ``decoder_dim``
    :param n: not used by this function
    :param weights_regularizer: forwarded to the output layer
    :return: output of the linear layer, (L,N,V+1) per the original comment
    """
    batch_size = tf.shape(latent)[0]
    length = params.flat_length
    with tf.variable_scope('decoder'):
        # (N,D) -> (1,N,D) -> (L,N,D)
        repeated = tf.tile(tf.expand_dims(latent, axis=0), (length, 1, 1))
        features = tf.concat(
            [repeated, linspace_feature(N=batch_size, L=length)], axis=-1)
        recurrent_out, _ = lstm(x=features,
                                num_layers=params.decoder_layers,
                                num_units=params.decoder_dim,
                                bidirectional=True)
        return slim.fully_connected(
            inputs=recurrent_out,
            num_outputs=vocab_size + 1,
            activation_fn=None,
            scope='decoder_mlp_output',
            weights_regularizer=weights_regularizer)  # (L,N,V+1)
Пример #7
0
def vae_flat_encoder_attn(tokens,
                          token_lengths,
                          vocab_size,
                          params,
                          n,
                          output_length,
                          weights_regularizer=None,
                          is_training=True):
    """Encode tokens into per-output-position ``mu`` / ``logsigma`` via attention.

    Stages (each in its own variable scope under ``encoder``):

    1. ``step_1``: embed the tokens, append a length-scaled position feature,
       run a bidirectional ``lstm`` and squash its final states into one flat
       encoding per example.
    2. ``step_2``: repeat the flat encoding ``output_length`` times, append a
       -1..1 position ramp and run a second bidirectional ``lstm``.
    3. ``encoder_attn``: attend from the output positions to the input hidden
       states (``calc_attn_v2``) and concatenate the attended inputs with the
       output hidden states.
    4. ``encoder_output``: batch norm, ``lstm``, batch norm, then two linear
       heads.

    :param tokens: (N,L)
    :param token_lengths: (N,)
    :param vocab_size: number of rows in the embedding table
    :param params: reads ``encoder_dim``, ``attention_dim`` and ``latent_dim``
    :param n: batch size used for tiling and reshaping
    :param output_length: number of output positions
    :param weights_regularizer: forwarded to every ``slim.fully_connected``
    :param is_training: forwarded to every ``slim.batch_norm`` call
    :return: ``(mu, logsigma)``, one vector per output position
    """
    L = tf.shape(tokens)[1]
    with tf.variable_scope('encoder'):
        with tf.variable_scope('step_1'):
            h = tf.transpose(tokens, (1, 0))  # (L,N)
            embeddings = tf.get_variable(
                dtype=tf.float32,
                name="embeddings",
                shape=[vocab_size, params.encoder_dim],
                initializer=tf.initializers.truncated_normal(
                    stddev=1. /
                    tf.sqrt(tf.constant(params.encoder_dim, dtype=tf.float32)))
            )
            h = tf.nn.embedding_lookup(embeddings, h)  # (L, N, D)
            # Position ramp rescaled by padded length / actual token length.
            ls = tf.linspace(start=tf.constant(0, dtype=tf.float32),
                             stop=tf.constant(1, dtype=tf.float32),
                             num=L)  # (L,)
            ls = tf.tile(tf.expand_dims(ls, 1), [1, n])  # (L,N)
            ls = ls * tf.cast(L, dtype=tf.float32) / tf.cast(
                tf.expand_dims(token_lengths, 0), dtype=tf.float32)
            ls = tf.expand_dims(ls, 2)  # (L,N,1)
            h = tf.concat([h, ls], axis=-1)
            hidden_state, hidden_state_final = lstm(
                x=h,
                num_layers=2,
                num_units=params.encoder_dim,
                bidirectional=True,
                sequence_lengths=token_lengths)
            h = tf.concat(hidden_state_final,
                          axis=-1)  # (layers*directions, N, D)
            h = tf.transpose(h, (1, 0, 2))  # (N,layers*directions,D)
            h = tf.reshape(h, (n, h.shape[1].value *
                               h.shape[2].value))  # (N, layers*directions*D)
            # Honour the caller's is_training flag (it used to be hard-coded
            # to True here, so inference normalised with batch statistics).
            h = slim.batch_norm(inputs=h, is_training=is_training)
            h = slim.fully_connected(inputs=h,
                                     num_outputs=params.encoder_dim,
                                     activation_fn=tf.nn.leaky_relu,
                                     scope='encoder_mlp_1',
                                     weights_regularizer=weights_regularizer)
            h = slim.batch_norm(inputs=h, is_training=is_training)
            # NOTE: hidden layer 'encoder_mlp_2' is currently disabled.
            flat_encoding = slim.fully_connected(
                inputs=h,
                num_outputs=params.encoder_dim,
                activation_fn=tf.nn.leaky_relu,
                scope='encoder_mlp_3',
                weights_regularizer=weights_regularizer)  # (N,D)
        with tf.variable_scope('step_2'):
            h = tf.expand_dims(flat_encoding, axis=0)  # (1, N, D)
            h = tf.tile(h, (output_length, 1, 1))  # (O,N,D)
            # The ramp must have exactly as many steps as h has positions,
            # otherwise the concat below fails; size it from output_length.
            ls = tf.linspace(start=-1., stop=1., num=output_length)  # (O,)
            ls = tf.tile(tf.expand_dims(tf.expand_dims(ls, 1), 2),
                         (1, n, 1))  # (O,N,1)
            h = tf.concat([h, ls], axis=2)
            output_hidden, _ = lstm(x=h,
                                    num_layers=2,
                                    num_units=params.encoder_dim,
                                    bidirectional=True)  # (O, N, D)
            output_hidden = slim.batch_norm(inputs=output_hidden,
                                            is_training=is_training)
        with tf.variable_scope('encoder_attn'):
            output_proj = slim.fully_connected(
                inputs=output_hidden,
                num_outputs=params.attention_dim,
                activation_fn=None,
                scope='encoder_output_proj',
                weights_regularizer=weights_regularizer)  # (O,N,A)
            # Projection of the *input* hidden states, so its leading axis is
            # the input length rather than the output length.
            input_proj = slim.fully_connected(
                inputs=hidden_state,
                num_outputs=params.attention_dim,
                activation_fn=None,
                scope='encoder_input_proj',
                weights_regularizer=weights_regularizer)  # (L,N,A)
            attn = calc_attn_v2(output_proj, input_proj,
                                token_lengths)  # (n, ol, il)
            tf.summary.image('encoder_attention', tf.expand_dims(attn, 3))
            input_aligned = tf.matmul(
                attn,  # (n, ol, il)
                tf.transpose(hidden_state, (1, 0, 2))  # (n, il, d)
            )  # (n, ol, d)
            h = tf.concat(
                [tf.transpose(input_aligned, (1, 0, 2)), output_hidden],
                axis=-1)
        with tf.variable_scope('encoder_output'):
            h = slim.batch_norm(h, is_training=is_training)
            h, _ = lstm(x=h,
                        num_layers=2,
                        num_units=params.encoder_dim,
                        bidirectional=True)  # (O, N, D)
            # NOTE: 'encoder_mlp_out_1' / 'encoder_mlp_out_2' are currently
            # disabled; the heads read the normalised LSTM output directly.
            h = slim.batch_norm(h, is_training=is_training)
            mu = slim.fully_connected(inputs=h,
                                      num_outputs=params.latent_dim,
                                      activation_fn=None,
                                      scope='encoder_mlp_mu',
                                      weights_regularizer=weights_regularizer)
            logsigma = slim.fully_connected(
                inputs=h,
                num_outputs=params.latent_dim,
                activation_fn=None,
                scope='encoder_mlp_logsigma',
                weights_regularizer=weights_regularizer)
            return mu, logsigma
Пример #8
0
def encoder_flat(tokens,
                 token_lengths,
                 vocab_size,
                 params,
                 n,
                 embeddings,
                 weights_regularizer=None,
                 is_training=True):
    """Encode tokens with a caller-supplied embedding table.

    Embedded tokens are concatenated with ``linspace_feature`` and
    ``linspace_scaled_feature`` (plus Gaussian noise in ``AAE_STOCH`` mode),
    run through a bidirectional ``lstm``; the second value returned by
    ``lstm`` is concatenated, flattened per example and projected.

    :param tokens: integer token ids, transposed to (L, N) before the lookup
    :param token_lengths: passed to ``lstm`` and ``linspace_scaled_feature``
    :param vocab_size: not used by this function (the table is passed in)
    :param params: reads ``model_mode``, ``noise_dim``, ``encoder_layers``,
        ``encoder_dim``, ``batch_norm`` and ``latent_dim``
    :param n: batch size used as the leading dimension of the flattening reshape
    :param embeddings: embedding table handed to ``tf.nn.embedding_lookup``
    :param weights_regularizer: forwarded to every ``slim.fully_connected``
    :param is_training: forwarded to ``slim.batch_norm``
    :return: ``(encoding, None)`` when ``params.model_mode`` is ``AAE_STOCH``
        or ``AE``; otherwise ``(mu, logsigma)``
    """
    with tf.variable_scope('encoder'):
        batch_size = tf.shape(tokens)[0]
        max_len = tf.shape(tokens)[1]
        token_ids = tf.transpose(tokens, (1, 0))  # (L,N)

        features = [
            tf.nn.embedding_lookup(embeddings, token_ids),  # (L, N, D)
            linspace_feature(L=max_len, N=batch_size),
            linspace_scaled_feature(L=max_len,
                                    N=batch_size,
                                    sequence_length=token_lengths),
        ]
        if params.model_mode == AAE_STOCH:
            # Stochastic mode feeds extra noise alongside the tokens.
            features.append(
                tf.random_normal(shape=(max_len, batch_size, params.noise_dim),
                                 dtype=tf.float32))
        lstm_input = tf.concat(features, axis=-1)

        _, h = lstm(x=lstm_input,
                    num_layers=params.encoder_layers,
                    num_units=params.encoder_dim,
                    bidirectional=True,
                    sequence_lengths=token_lengths)
        print("h1: {}".format(h))
        h = tf.concat(h, axis=-1)
        print("h2: {}".format(h))
        h = tf.transpose(h, (1, 0, 2))  # batch axis first
        print("h3: {}".format(h))
        # Collapse the two trailing (statically known) axes into one.
        flat_width = h.shape[1].value * h.shape[2].value
        h = tf.reshape(h, (n, flat_width))
        print("h4: {}".format(h))
        if params.batch_norm:
            h = slim.batch_norm(h, is_training=is_training)
        # NOTE: a stack of 'encoder_mlp_{i}' hidden layers is currently
        # disabled; the heads read the flattened state directly.

        # All heads share everything except their variable scope.
        head_kwargs = dict(inputs=h,
                           num_outputs=params.latent_dim,
                           activation_fn=None,
                           weights_regularizer=weights_regularizer)

        if params.model_mode == AAE_STOCH or params.model_mode == AE:
            encoding = slim.fully_connected(scope='encoder_mlp_encoding',
                                            **head_kwargs)
            return encoding, None

        mu = slim.fully_connected(scope='encoder_mlp_mu', **head_kwargs)
        logsigma = slim.fully_connected(scope='encoder_mlp_logsigma',
                                        **head_kwargs)
        return mu, logsigma
def encoder_bintree_attn_base(inputs,
                              token_lengths,
                              params,
                              weights_regularizer=None,
                              is_training=True):
    """Build per-tree-level encoder features using attention over the inputs.

    Stages:

    1. ``input_lstm``: run a 3-layer bidirectional ``lstm`` over ``inputs`` and
       squash its final states into one flat encoding per example.
    2. ``bintree_attention``: expand the flat encoding with
       ``binary_tree_down``, attend from every tree level to the input hidden
       states, and concatenate the attended inputs with the tree level.
    3. ``encoder_bintree_up`` / ``encoder_bintree_down``: pass messages with
       ``binary_tree_up`` and ``binary_tree_down``, concatenating each
       result onto the per-level features.

    :param inputs: time-major input fed directly to ``lstm``; axis 1 is read
        as the batch size (the original note gives the shape as (L,N))
    :param token_lengths: passed to ``lstm`` and ``calc_attn_v2``
    :param params: reads ``encoder_dim``, ``batch_norm``, ``attention_dim``
        and ``tree_depth``
    :param weights_regularizer: forwarded to every ``slim.fully_connected``
    :param is_training: forwarded to ``slim.batch_norm``
    :return: list of tensors, one per element of the tree returned by
        ``binary_tree_down``
    """
    n = tf.shape(inputs)[1]
    with tf.variable_scope('input_lstm'):
        h = inputs
        hidden_state, hidden_state_final = lstm(x=h,
                                                num_layers=3,
                                                num_units=params.encoder_dim,
                                                bidirectional=True,
                                                sequence_lengths=token_lengths)
        h = tf.concat(hidden_state_final, axis=-1)  # (layers*directions, N, D)
        h = tf.transpose(h, (1, 0, 2))  # (N,layers*directions,D)
        # Flatten the two trailing (statically known) axes into one.
        h = tf.reshape(h, (n, h.shape[1].value *
                           h.shape[2].value))  # (N, layers*directions*D)
        """
        if params.batch_norm:
            h = slim.batch_norm(inputs=h, is_training=is_training)
        h = slim.fully_connected(
            inputs=h,
            num_outputs=params.encoder_dim,
            activation_fn=tf.nn.leaky_relu,
            scope='encoder_mlp_1',
            weights_regularizer=weights_regularizer
        )
        """
        if params.batch_norm:
            h = slim.batch_norm(inputs=h, is_training=is_training)
        flat_encoding = slim.fully_connected(
            inputs=h,
            num_outputs=params.encoder_dim,
            activation_fn=tf.nn.leaky_relu,
            scope='encoder_mlp_3',
            weights_regularizer=weights_regularizer)  # (N,D)
    with tf.variable_scope('bintree_attention'):
        # TODO: recurrent attention
        output_tree = binary_tree_down(x0=flat_encoding,
                                       hidden_dim=params.encoder_dim,
                                       depth=params.tree_depth)
        # One shared projection for all tree levels: the first call creates
        # the variables, every later call reuses them (reuse=i > 0).
        output_projs = [
            slim.fully_connected(inputs=enc,
                                 num_outputs=params.attention_dim,
                                 activation_fn=None,
                                 scope='encoder_mlp_encoding',
                                 weights_regularizer=weights_regularizer,
                                 reuse=i > 0)
            for i, enc in enumerate(output_tree)
        ]
        # Projection of the input hidden states (leading axis follows the
        # input sequence, not the output tree).
        input_proj = slim.fully_connected(
            inputs=hidden_state,
            num_outputs=params.attention_dim,
            activation_fn=None,
            scope='encoder_input_proj',
            weights_regularizer=weights_regularizer)
        # One attention map per tree level.
        attns = [
            calc_attn_v2(output_proj,
                         input_proj,
                         token_lengths,
                         a_transpose=False,
                         b_transpose=True) for output_proj in output_projs
        ]  # each (n, ol, il) per the original note

        # Stack the per-level attention maps in infix order so they can be
        # logged as a single image summary.
        attn_idx = infix_indices(params.tree_depth)
        flat_attns = stack_tree(attns, indices=attn_idx)
        attn_img = tf.expand_dims(tf.transpose(flat_attns, (1, 0, 2)), axis=3)
        tf.summary.image('encoder_attention', attn_img)

        # Attention-weighted sum of the input hidden states for every level.
        hs = [
            tf.matmul(
                attn,  # (n, ol, il)
                tf.transpose(hidden_state, (1, 0, 2)))  # (n, il, d)
            for attn in attns
        ]
        # (n, ol, d)
        # Pair each level's attended inputs with that level of the tree.
        hs = [tf.concat(cols, axis=-1) for cols in zip(hs, output_tree)]
    with tf.variable_scope('encoder_bintree_up'):
        messages_up = binary_tree_up(hidden_dim=params.encoder_dim, inputs=hs)
    with tf.variable_scope('encoder_bintree_down'):
        hs = [tf.concat(cols, axis=-1) for cols in zip(hs, messages_up)]
        # The first level has a singleton axis 1, which is squeezed away to
        # seed the downward pass.
        messages_down = binary_tree_down(x0=tf.squeeze(hs[0], axis=1),
                                         hidden_dim=params.encoder_dim,
                                         depth=params.tree_depth,
                                         inputs=hs)
        hs = [tf.concat(cols, axis=-1) for cols in zip(hs, messages_down)]
    return hs