Example #1
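The snippets below are excerpted from a TensorFlow 1.x codebase and omit their imports. A plausible header is sketched here; the commented-out project-local modules (sequence, vocab, DecoderOutputLayer, create_rnn_layer, create_trainable_initial_states) are assumptions about the surrounding repository, not confirmed paths.

import tensorflow as tf
from tensorflow.contrib import rnn as tf_rnn
from tensorflow.contrib import seq2seq

# Project-local helpers referenced below (assumed; exact module paths unknown):
# from <project> import sequence, vocab
# from <project>.decoder import DecoderOutputLayer, create_rnn_layer, create_trainable_initial_states
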
def bidirectional_encoder(src,
                          src_length,
                          hidden_dim,
                          num_layer,
                          dropout_keep,
                          swap_memory=False,
                          use_dropout=False,
                          reuse=None,
                          name=None):
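    """Encodes `src` with a stacked bidirectional LSTM.

    Args:
        src: tensor in shape of [batch x max_len x embed_dim]
        src_length: tensor in shape of [batch]
        hidden_dim: total number of hidden units; each direction gets hidden_dim // 2
        num_layer: number of stacked LSTM layers per direction
        dropout_keep: output keep probability used when use_dropout is True
        swap_memory: pass-through to tf.nn.bidirectional_dynamic_rnn
        use_dropout: wrap each cell in a DropoutWrapper when True
        reuse: reuse variables
        name: op name

    Returns:
        output, a tensor in shape of [batch x max_len x hidden_dim]
        final_state, the output at the last relevant time step
    """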
    with tf.variable_scope(name,
                           'encoder',
                           values=[src, src_length],
                           reuse=reuse):

        def create_rnn_layer(layer_num, dim):
            cell = tf_rnn.LSTMCell(dim, name='layer_%s' % layer_num)
            if use_dropout and dropout_keep < 1.:
                cell = tf_rnn.DropoutWrapper(cell,
                                             output_keep_prob=dropout_keep)

            if layer_num > 0:
                cell = tf_rnn.ResidualWrapper(cell)

            return cell

        batch_size = tf.shape(src)[0]

        fw_cell = tf_rnn.MultiRNNCell(
            [create_rnn_layer(i, hidden_dim // 2) for i in range(num_layer)])
        bw_cell = tf_rnn.MultiRNNCell(
            [create_rnn_layer(i, hidden_dim // 2) for i in range(num_layer)])

        fw_zero_state = sequence.create_trainable_initial_states(
            batch_size, fw_cell, 'fw_zs')
        bw_zero_state = sequence.create_trainable_initial_states(
            batch_size, bw_cell, 'bw_zs')

        outputs, state = tf.nn.bidirectional_dynamic_rnn(
            fw_cell,
            bw_cell,
            src,
            src_length,
            fw_zero_state,
            bw_zero_state,
            swap_memory=swap_memory)

        output = tf.concat(outputs, axis=2)
        final_state = sequence.last_relevant(output, src_length)

    return output, final_state
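
A minimal usage sketch for bidirectional_encoder. The placeholder names, shapes, and dimensions are illustrative assumptions, and the project-local sequence module from this repository must be importable:

src_embedded = tf.placeholder(tf.float32, [None, None, 128])  # [batch x max_len x embed_dim]
src_len = tf.placeholder(tf.int32, [None])                    # [batch]

enc_out, enc_final = bidirectional_encoder(src_embedded,
                                            src_len,
                                            hidden_dim=256,
                                            num_layer=2,
                                            dropout_keep=0.8,
                                            use_dropout=True,
                                            name='bi_enc')
# enc_out:   [batch x max_len x 256] -- forward/backward halves of 128 units each, concatenated
# enc_final: [batch x 256]           -- output at the last relevant time step
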
def context_encoder(words,
                    lengths,
                    hidden_dim,
                    num_layers,
                    name=None,
                    reuse=None):
    """
    Args:
        words: tensor in shape of [batch x max_len x embed_dim]
        lengths: tensor in shape of [batch]
        hidden_dim: num of lstm hidden units
        name: op name
        reuse: reuse variable

    Returns:
        aggregation_result, a tensor in shape of [batch x max_len x hidden_dim]

    """
    with tf.variable_scope(name,
                           'context_encoder', [words, lengths],
                           reuse=reuse):

        def create_rnn_layer(layer_num, dim):
            if layer_num == 0:
                return tf_rnn.LSTMCell(dim, name='layer_%s' % layer_num)

            cell = tf_rnn.LSTMCell(dim, name='layer_%s' % layer_num)
            cell = tf_rnn.ResidualWrapper(cell)
            return cell

        batch_size = tf.shape(words)[0]

        fw_cell = tf_rnn.MultiRNNCell(
            [create_rnn_layer(i, hidden_dim // 2) for i in range(num_layers)])
        bw_cell = tf_rnn.MultiRNNCell(
            [create_rnn_layer(i, hidden_dim // 2) for i in range(num_layers)])

        fw_zero_state = sequence.create_trainable_initial_states(
            batch_size, fw_cell, 'fw_zs')
        bw_zero_state = sequence.create_trainable_initial_states(
            batch_size, bw_cell, 'bw_zs')

        outputs, state = tf.nn.bidirectional_dynamic_rnn(
            fw_cell, bw_cell, words, lengths, fw_zero_state, bw_zero_state)

        output = tf.concat(outputs, axis=2)
        assert output.shape[2] == hidden_dim

    return output
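
A similar sketch for context_encoder (illustrative shapes only; note this variant exposes no dropout or swap_memory options):

ctx_words = tf.placeholder(tf.float32, [None, None, 300])  # [batch x max_len x embed_dim]
ctx_lens = tf.placeholder(tf.int32, [None])                 # [batch]

ctx = context_encoder(ctx_words, ctx_lens, hidden_dim=512, num_layers=2, name='ctx_enc')
# ctx: [batch x max_len x 512]; each direction contributes 512 // 2 = 256 units before the concat.
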
def word_aggregator(words,
                    lengths,
                    hidden_dim,
                    num_layers,
                    swap_memory=False,
                    use_dropout=False,
                    dropout_keep=1.0,
                    name=None,
                    reuse=None):
    """
    Args:
        words: tensor in shape of [batch x max_len x embed_dim]
        lengths: tensor in shape of [batch]
        hidden_dim: num of lstm hidden units
        name: op name
        reuse: reuse variable

    Returns:
        aggregation_result, a tensor in shape of [batch x max_len x hidden_dim]

    """
    with tf.variable_scope(name,
                           'word_aggregator', [words, lengths],
                           reuse=reuse):
        batch_size = tf.shape(words)[0]

        def create_rnn_layer(layer_num, dim):
            cell = tf_rnn.LSTMCell(dim, name='layer_%s' % layer_num)
            if use_dropout and dropout_keep < 1.:
                cell = tf_rnn.DropoutWrapper(cell,
                                             output_keep_prob=dropout_keep)

            if layer_num > 0:
                cell = tf_rnn.ResidualWrapper(cell)

            return cell

        cell = tf_rnn.MultiRNNCell(
            [create_rnn_layer(i, hidden_dim) for i in range(num_layers)])
        zero_state = sequence.create_trainable_initial_states(batch_size, cell)

        outputs, last_state = tf.nn.dynamic_rnn(cell,
                                                words,
                                                sequence_length=lengths,
                                                initial_state=zero_state,
                                                swap_memory=swap_memory)

        return outputs
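
A hedged usage sketch for word_aggregator; unlike the bidirectional encoders above it runs a single forward LSTM stack, so the full hidden_dim goes to one direction (names and sizes are illustrative):

agg_words = tf.placeholder(tf.float32, [None, None, 300])  # [batch x max_len x embed_dim]
agg_lens = tf.placeholder(tf.int32, [None])                 # [batch]

agg = word_aggregator(agg_words,
                      agg_lens,
                      hidden_dim=256,
                      num_layers=2,
                      use_dropout=True,
                      dropout_keep=0.9,
                      name='agg')
# agg: [batch x max_len x 256] -- per-step outputs of the top LSTM layer
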
Example #4
def residual_decoder(agenda,
                     dec_inputs,
                     dec_input_lengths,
                     hidden_dim,
                     num_layer,
                     swap_memory,
                     enable_dropout=False,
                     dropout_keep=1.,
                     name=None):
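    """Decodes with a stacked residual LSTM, conditioning every step on `agenda`.

    Args:
        agenda: tensor in shape of [batch x agenda_dim]
        dec_inputs: int tensor in shape of [batch x max_len] of token ids
        dec_input_lengths: tensor in shape of [batch]
        hidden_dim: total number of hidden units; each cell gets hidden_dim // 2
        num_layer: number of stacked LSTM layers
        swap_memory: pass-through to seq2seq.dynamic_decode
        enable_dropout: wrap each cell in a DropoutWrapper when True
        dropout_keep: output keep probability
        name: op name

    Returns:
        outputs, state, length as returned by seq2seq.dynamic_decode
    """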
    with tf.variable_scope(name, 'residual_decoder', []):
        batch_size = tf.shape(dec_inputs)[0]
        embeddings = vocab.get_embeddings()

        # Concatenate agenda [y_hat;base_input_embed] with decoder inputs

        # [batch x max_len x word_dim]
        dec_inputs = tf.nn.embedding_lookup(embeddings, dec_inputs)
        max_len = tf.shape(dec_inputs)[1]

        # [batch x 1 x agenda_dim]
        agenda = tf.expand_dims(agenda, axis=1)

        # [batch x max_len x agenda_dim]
        agenda = tf.tile(agenda, [1, max_len, 1])

        # [batch x max_len x word_dim+agenda_dim]
        dec_inputs = tf.concat([dec_inputs, agenda], axis=2)

        helper = seq2seq.TrainingHelper(dec_inputs,
                                        dec_input_lengths,
                                        name='train_helper')
        cell = tf_rnn.MultiRNNCell([
            create_rnn_layer(i, hidden_dim // 2, enable_dropout, dropout_keep)
            for i in range(num_layer)
        ])
        zero_states = create_trainable_initial_states(batch_size, cell)

        output_layer = DecoderOutputLayer(embeddings)
        decoder = seq2seq.BasicDecoder(cell, helper, zero_states, output_layer)

        outputs, state, length = seq2seq.dynamic_decode(
            decoder, swap_memory=swap_memory)

        return outputs, state, length
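
The agenda tiling step inside residual_decoder can be checked in isolation. A self-contained sketch assuming only TensorFlow 1.x; the zero constants stand in for real embeddings:

agenda_demo = tf.zeros([4, 32])         # [batch x agenda_dim]
dec_embed_demo = tf.zeros([4, 7, 128])  # [batch x max_len x word_dim]

max_len_demo = tf.shape(dec_embed_demo)[1]
tiled = tf.tile(tf.expand_dims(agenda_demo, 1), [1, max_len_demo, 1])  # [batch x max_len x agenda_dim]
merged = tf.concat([dec_embed_demo, tiled], axis=2)                    # [batch x max_len x word_dim + agenda_dim]

with tf.Session() as sess:
    print(sess.run(tf.shape(merged)))  # -> [4 7 160]
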
def word_aggregator(words,
                    lengths,
                    hidden_dim,
                    num_layers,
                    name=None,
                    reuse=None):
    """
    Args:
        words: tensor in shape of [batch x max_len x embed_dim]
        lengths: tensor in shape of [batch]
        hidden_dim: num of lstm hidden units
        name: op name
        reuse: reuse variable

    Returns:
        aggregation_result, a tensor in shape of [batch x max_len x hidden_dim]

    """
    with tf.variable_scope(name,
                           'word_aggregator', [words, lengths],
                           reuse=reuse):
        batch_size = tf.shape(words)[0]

        def create_rnn_layer(layer_num):
            if layer_num == 0:
                return tf_rnn.LSTMCell(hidden_dim, name='layer_%s' % layer_num)

            cell = tf_rnn.LSTMCell(hidden_dim, name='layer_%s' % layer_num)
            cell = tf_rnn.ResidualWrapper(cell)
            return cell

        cell = tf_rnn.MultiRNNCell(
            [create_rnn_layer(i) for i in range(num_layers)])
        zero_state = sequence.create_trainable_initial_states(batch_size, cell)

        outputs, last_state = tf.nn.dynamic_rnn(cell,
                                                words,
                                                sequence_length=lengths,
                                                initial_state=zero_state)

        return outputs
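
This word_aggregator variant matches the one in Example #1 except that it omits the DropoutWrapper and swap_memory options; usage follows the earlier sketch with those two arguments dropped.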