Example No. 1
        def _time_step(time, input_, state, output, proj_outputs, decoder_outputs, samples, states, weights,
                       prev_weights):
            context_vector, new_weights = attention_(state, prev_weights=prev_weights)
            weights = weights.write(time, new_weights)

            # FIXME use `output` or `state` here?
            output_ = linear_unsafe([state, input_, context_vector], decoder.cell_size, False, scope='maxout')
            output_ = tf.reduce_max(tf.reshape(output_, tf.stack([batch_size, decoder.cell_size // 2, 2])), axis=2)
            output_ = linear_unsafe(output_, decoder.embedding_size, False, scope='softmax0')
            decoder_outputs = decoder_outputs.write(time, output_)
            output_ = linear_unsafe(output_, output_size, True, scope='softmax1')
            proj_outputs = proj_outputs.write(time, output_)

            argmax = lambda: tf.argmax(output_, 1)
            softmax = lambda: tf.squeeze(tf.multinomial(tf.log(tf.nn.softmax(output_)), num_samples=1),
                                         axis=1)
            target = lambda: inputs.read(time + 1)

            sample = tf.case([
                (tf.logical_and(time < time_steps - 1, tf.random_uniform([]) >= feed_previous), target),
                (tf.logical_not(feed_argmax), softmax)],
                default=argmax)   # default case is useful for beam-search

            sample.set_shape([None])
            sample = tf.stop_gradient(sample)

            samples = samples.write(time, sample)
            input_ = embed(sample)

            x = tf.concat([input_, context_vector], 1)
            call_cell = lambda: unsafe_decorator(cell)(x, state)

            if sequence_length is not None:
                new_output, new_state = rnn._rnn_step(
                    time=time,
                    sequence_length=sequence_length,
                    min_sequence_length=min_sequence_length,
                    max_sequence_length=max_sequence_length,
                    zero_output=zero_output,
                    state=state,
                    call_cell=call_cell,
                    state_size=state_size,
                    skip_conditionals=True)
            else:
                new_output, new_state = call_cell()

            states = states.write(time, new_state)

            return (time + 1, input_, new_state, new_output, proj_outputs, decoder_outputs, samples, states, weights,
                    new_weights)
Example No. 2
def compute_energy_with_filter(hidden, state, prev_weights, attention_filters, attention_filter_length,
                               **kwargs):
    time_steps = tf.shape(hidden)[1]
    attn_size = hidden.get_shape()[3].value
    batch_size = tf.shape(hidden)[0]

    filter_shape = [attention_filter_length * 2 + 1, 1, 1, attention_filters]
    filter_ = get_variable_unsafe('filter', filter_shape)
    u = get_variable_unsafe('U', [attention_filters, attn_size])
    prev_weights = tf.reshape(prev_weights, tf.stack([batch_size, time_steps, 1, 1]))
    conv = tf.nn.conv2d(prev_weights, filter_, [1, 1, 1, 1], 'SAME')
    shape = tf.stack([tf.multiply(batch_size, time_steps), attention_filters])
    conv = tf.reshape(conv, shape)
    z = tf.matmul(conv, u)
    z = tf.reshape(z, tf.stack([batch_size, time_steps, 1, attn_size]))

    y = linear_unsafe(state, attn_size, True)
    y = tf.reshape(y, [-1, 1, 1, attn_size])

    k = get_variable_unsafe('W', [attn_size, attn_size])

    # dot product between tensors requires reshaping
    hidden = tf.reshape(hidden, tf.stack([tf.multiply(batch_size, time_steps), attn_size]))
    f = tf.matmul(hidden, k)
    f = tf.reshape(f, tf.stack([batch_size, time_steps, 1, attn_size]))

    v = get_variable_unsafe('V', [attn_size])
    s = f + y + z
    return tf.reduce_sum(v * tf.tanh(s), [2, 3])
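
This variant adds a location-aware term: the previous attention weights are convolved with a learned filter and the result enters the energy alongside the usual encoder and decoder terms. A rough NumPy sketch of the same computation for a single batch element, with random weights, biases omitted, and purely illustrative names:

import numpy as np

def location_aware_energy(hidden, state, prev_weights, filter_, U, W, W_s, v):
    """hidden: (T, H) encoder states, state: (H,) decoder state,
    prev_weights: (T,) previous attention weights, filter_: (F_len, F) filters."""
    T, H = hidden.shape
    F_len, _ = filter_.shape
    pad = F_len // 2
    padded = np.pad(prev_weights, pad)
    conv = np.stack([padded[t:t + F_len] @ filter_ for t in range(T)])  # (T, F)
    z = conv @ U                   # location features, (T, H)
    y = state @ W_s                # decoder term, (H,)
    f = hidden @ W                 # encoder term, (T, H)
    return np.tanh(f + y + z) @ v  # unnormalised energies, (T,)

T, H, F = 5, 8, 3
e = location_aware_energy(np.random.randn(T, H), np.random.randn(H), np.random.rand(T),
                          np.random.randn(2 * 1 + 1, F), np.random.randn(F, H),
                          np.random.randn(H, H), np.random.randn(H, H), np.random.randn(H))
print(e.shape)  # (5,)
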
Example No. 3
def compute_energy(hidden, state, attn_size, **kwargs):
    input_size = hidden.get_shape()[3].value
    batch_size = tf.shape(hidden)[0]
    time_steps = tf.shape(hidden)[1]

    # initializer = tf.random_normal_initializer(stddev=0.001)   # same as Bahdanau et al.
    initializer = None
    y = linear_unsafe(state,
                      attn_size,
                      True,
                      scope='W_a',
                      initializer=initializer)
    y = tf.reshape(y, [-1, 1, attn_size])

    k = get_variable_unsafe('U_a', [input_size, attn_size],
                            initializer=initializer)

    # dot product between tensors requires reshaping
    hidden = tf.reshape(
        hidden, tf.stack([tf.multiply(batch_size, time_steps), input_size]))
    f = tf.matmul(hidden, k)
    f = tf.reshape(f, tf.stack([batch_size, time_steps, attn_size]))

    v = get_variable_unsafe('v_a', [attn_size])
    s = f + y

    return tf.reduce_sum(v * tf.tanh(s), [2])
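
This is the standard additive (Bahdanau-style) scoring function, e_tj = v_a · tanh(W_a s_t + U_a h_j), computed for every source position at once by flattening the batch and time axes. A small NumPy sketch of the same computation, with illustrative names:

import numpy as np

def additive_energy(hidden, state, W_a, U_a, v_a):
    """hidden: (B, T, H) encoder states, state: (B, S) decoder state."""
    y = state @ W_a                           # (B, A)
    f = hidden @ U_a                          # (B, T, A)
    return np.tanh(f + y[:, None, :]) @ v_a   # energies, (B, T)

B, T, H, S, A = 2, 6, 10, 12, 8
e = additive_energy(np.random.randn(B, T, H), np.random.randn(B, S),
                    np.random.randn(S, A), np.random.randn(H, A), np.random.randn(A))
print(e.shape)  # (2, 6)
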
Example No. 4
def compute_energy(hidden, state, name, **kwargs):
    attn_size = hidden.get_shape()[3].value
    batch_size = tf.shape(hidden)[0]
    time_steps = tf.shape(hidden)[1]

    y = linear_unsafe(state, attn_size, True, scope=name)
    y = tf.reshape(y, [-1, 1, 1, attn_size])

    k = get_variable_unsafe('W_{}'.format(name), [attn_size, attn_size])

    # dot product between tensors requires reshaping
    hidden = tf.reshape(hidden, tf.pack([tf.mul(batch_size, time_steps), attn_size]))
    f = tf.matmul(hidden, k)
    f = tf.reshape(f, tf.pack([batch_size, time_steps, 1, attn_size]))

    v = get_variable_unsafe('V_{}'.format(name), [attn_size])
    s = f + y

    return tf.reduce_sum(v * tf.tanh(s), [2, 3])
Example No. 5
        def _time_step(time, state, _, attn_weights, output_ta_t, attn_weights_ta_t):
            input_t = input_ta.read(time)
            # restore some shape information
            r = tf.random_uniform([])
            input_t = tf.cond(tf.logical_and(time > 0, r < feed_previous),
                              lambda: tf.stop_gradient(extract_argmax_and_embed(output_ta_t.read(time - 1))),
                              lambda: input_t)
            input_t.set_shape(decoder_inputs.get_shape()[1:])
            # the code from TensorFlow used a concatenation of input_t and attns as input here
            # TODO: evaluate the impact of this
            call_cell = lambda: unsafe_decorator(cell)(input_t, state)

            if sequence_length is not None:
                output, new_state = rnn._rnn_step(
                    time=time,
                    sequence_length=sequence_length,
                    min_sequence_length=min_sequence_length,
                    max_sequence_length=max_sequence_length,
                    zero_output=zero_output,
                    state=state,
                    call_cell=call_cell,
                    state_size=state_size,
                    skip_conditionals=True)
            else:
                output, new_state = call_cell()

            attn_weights_ta_t = attn_weights_ta_t.write(time, attn_weights)
            # using decoder state instead of decoder output in the attention model seems
            # to give much better results
            new_attns, new_attn_weights = attention_(new_state, prev_weights=attn_weights)

            with tf.variable_scope('attention_output_projection'):  # this can take a lot of memory
                output = linear_unsafe([output, new_attns], output_size, True)

            output_ta_t = output_ta_t.write(time, output)
            return time + 1, new_state, new_attns, new_attn_weights, output_ta_t, attn_weights_ta_t
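
The tf.cond at the top of this step implements scheduled sampling: with probability `feed_previous` the decoder is fed (the embedding of) its own previous prediction instead of the ground-truth input. A minimal Python sketch of just that input-selection policy, with illustrative names:

import numpy as np

def next_input(time, ground_truth, prev_prediction, feed_previous, rng=np.random):
    """Pick the token to feed at step `time` under scheduled sampling."""
    if time > 0 and rng.uniform() < feed_previous:
        return prev_prediction     # feed back the model's own previous output
    return ground_truth[time]      # teacher forcing

targets = [2, 17, 5, 9]            # illustrative token ids
print(next_input(2, targets, prev_prediction=17, feed_previous=0.5))
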
Example No. 6
def attention_decoder(targets, initial_state, attention_states, encoders, decoder, encoder_input_length,
                      decoder_input_length=None, dropout=None, feed_previous=0.0, feed_argmax=True,
                      **kwargs):
    """
    :param targets: tensor of shape (output_length, batch_size)
    :param initial_state: initial state of the decoder (usually the final state of the encoder),
      as a tensor of shape (batch_size, initial_state_size). This state is mapped to the
      correct state size for the decoder.
    :param attention_states: list of tensors of shape (batch_size, input_length, encoder_cell_size),
      usually the encoder outputs (one tensor for each encoder).
    :param encoders: configuration of the encoders
    :param decoder: configuration of the decoder
    :param decoder_input_length: tensor of shape (batch_size) with the length of each target sequence, or None
    :param dropout: scalar tensor or None, specifying the keep probability (1 - dropout)
    :param feed_previous: scalar tensor giving the probability of feeding the previous decoder output
      instead of the ground truth as input to the decoder (1 when decoding, between 0 and 1 when training)
    :return:
      outputs of the decoder as a tensor of shape (batch_size, output_length, decoder_cell_size)
      attention weights as a tensor of shape (output_length, encoders, batch_size, input_length)
    """
    # TODO: dropout instead of keep probability
    assert decoder.cell_size % 2 == 0, 'cell size must be a multiple of 2'   # because of maxout

    decoder_inputs = targets[:-1,:]  # starts with BOS

    if decoder.get('embedding') is not None:
        initializer = decoder.embedding
        embedding_shape = None
    else:
        initializer = None
        embedding_shape = [decoder.vocab_size, decoder.embedding_size]

    with tf.device('/cpu:0'):
        embedding = get_variable_unsafe('embedding_{}'.format(decoder.name), shape=embedding_shape,
                                        initializer=initializer)

    if decoder.use_lstm:
        cell = BasicLSTMCell(decoder.cell_size, state_is_tuple=False)
    else:
        cell = GRUCell(decoder.cell_size, initializer=orthogonal_initializer())

    if dropout is not None:
        cell = DropoutWrapper(cell, input_keep_prob=dropout)

    if decoder.layers > 1:
        cell = MultiRNNCell([cell] * decoder.layers, residual_connections=decoder.residual_connections)

    with tf.variable_scope('decoder_{}'.format(decoder.name)):
        def embed(input_):
            if embedding is not None:
                return tf.nn.embedding_lookup(embedding, input_)
            else:
                return input_

        hidden_states = [tf.expand_dims(states, 2) for states in attention_states]
        attention_ = functools.partial(multi_attention, hidden_states=hidden_states, encoders=encoders,
                                       encoder_input_length=encoder_input_length)

        input_shape = tf.shape(decoder_inputs)
        time_steps = input_shape[0]
        batch_size = input_shape[1]
        output_size = decoder.vocab_size
        state_size = cell.state_size

        if initial_state is not None:
            if dropout is not None:
                initial_state = tf.nn.dropout(initial_state, dropout)

            state = tf.nn.tanh(
                linear_unsafe(initial_state, state_size, True, scope='initial_state_projection')
            )
        else:
            # if no initial state, initialize with zeros (this is the case for MIXER)
            state = tf.zeros([batch_size, state_size], dtype=tf.float32)

        sequence_length = decoder_input_length
        if sequence_length is not None:
            sequence_length = tf.to_int32(sequence_length)
            min_sequence_length = tf.reduce_min(sequence_length)
            max_sequence_length = tf.reduce_max(sequence_length)

        time = tf.constant(0, dtype=tf.int32, name='time')
        zero_output = tf.zeros(tf.stack([batch_size, cell.output_size]), tf.float32)

        proj_outputs = tf.TensorArray(dtype=tf.float32, size=time_steps, clear_after_read=False)
        decoder_outputs = tf.TensorArray(dtype=tf.float32, size=time_steps)

        inputs = tf.TensorArray(dtype=tf.int64, size=time_steps, clear_after_read=False).unstack(
                                tf.cast(decoder_inputs, tf.int64))
        samples = tf.TensorArray(dtype=tf.int64, size=time_steps, clear_after_read=False)
        states = tf.TensorArray(dtype=tf.float32, size=time_steps)

        attn_lengths = [tf.shape(states_)[1] for states_ in attention_states]  # avoid shadowing the `states` TensorArray

        weights = tf.TensorArray(dtype=tf.float32, size=time_steps)
        initial_weights = [tf.zeros(tf.stack([batch_size, length])) for length in attn_lengths]

        output = tf.zeros(tf.stack([batch_size, cell.output_size]), dtype=tf.float32)

        initial_input = embed(inputs.read(0))   # first symbol is BOS

        def _time_step(time, input_, state, output, proj_outputs, decoder_outputs, samples, states, weights,
                       prev_weights):
            context_vector, new_weights = attention_(state, prev_weights=prev_weights)
            weights = weights.write(time, new_weights)

            # FIXME use `output` or `state` here?
            output_ = linear_unsafe([state, input_, context_vector], decoder.cell_size, False, scope='maxout')
            output_ = tf.reduce_max(tf.reshape(output_, tf.stack([batch_size, decoder.cell_size // 2, 2])), axis=2)
            output_ = linear_unsafe(output_, decoder.embedding_size, False, scope='softmax0')
            decoder_outputs = decoder_outputs.write(time, output_)
            output_ = linear_unsafe(output_, output_size, True, scope='softmax1')
            proj_outputs = proj_outputs.write(time, output_)

            argmax = lambda: tf.argmax(output_, 1)
            softmax = lambda: tf.squeeze(tf.multinomial(tf.log(tf.nn.softmax(output_)), num_samples=1),
                                         axis=1)
            target = lambda: inputs.read(time + 1)

            sample = tf.case([
                (tf.logical_and(time < time_steps - 1, tf.random_uniform([]) >= feed_previous), target),
                (tf.logical_not(feed_argmax), softmax)],
                default=argmax)   # default case is useful for beam-search

            sample.set_shape([None])
            sample = tf.stop_gradient(sample)

            samples = samples.write(time, sample)
            input_ = embed(sample)

            x = tf.concat([input_, context_vector], 1)
            call_cell = lambda: unsafe_decorator(cell)(x, state)

            if sequence_length is not None:
                new_output, new_state = rnn._rnn_step(
                    time=time,
                    sequence_length=sequence_length,
                    min_sequence_length=min_sequence_length,
                    max_sequence_length=max_sequence_length,
                    zero_output=zero_output,
                    state=state,
                    call_cell=call_cell,
                    state_size=state_size,
                    skip_conditionals=True)
            else:
                new_output, new_state = call_cell()

            states = states.write(time, new_state)

            return (time + 1, input_, new_state, new_output, proj_outputs, decoder_outputs, samples, states, weights,
                    new_weights)

        _, _, new_state, new_output, proj_outputs, decoder_outputs, samples, states, weights, _ = tf.while_loop(
            cond=lambda time, *_: time < time_steps,
            body=_time_step,
            # loop variable order must match the argument order of `_time_step`
            loop_vars=(time, initial_input, state, output, proj_outputs, decoder_outputs, samples, states, weights,
                       initial_weights),
            parallel_iterations=decoder.parallel_iterations,
            swap_memory=decoder.swap_memory)

        proj_outputs = proj_outputs.stack()
        decoder_outputs = decoder_outputs.stack()
        samples = samples.stack()
        weights = weights.stack()  # output time, encoders, batch_size, input time
        states = states.stack()

        # weights = tf.Print(weights, [weights[:,0]], summarize=20)
        # tf.control_dependencies()

        beam_tensors = namedtuple('beam_tensors', 'state new_state output new_output')
        return (proj_outputs, weights, decoder_outputs, beam_tensors(state, new_state, output, new_output),
                samples, states)
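
`multi_attention` itself is not shown in these examples; conceptually it turns the energies produced by the `compute_energy` functions into normalised weights with a length-masked softmax and returns their weighted sum over the encoder states as the context vector. A sketch of that step for a single encoder, under that assumption and with illustrative names:

import numpy as np

def attention_step(energies, hidden, lengths):
    """energies: (B, T), hidden: (B, T, H), lengths: (B,) valid source lengths."""
    B, T = energies.shape
    mask = np.arange(T)[None, :] < lengths[:, None]        # (B, T)
    e = np.where(mask, energies, -np.inf)
    weights = np.exp(e - e.max(axis=1, keepdims=True))
    weights /= weights.sum(axis=1, keepdims=True)          # masked softmax
    context = (weights[:, :, None] * hidden).sum(axis=1)   # (B, H)
    return context, weights

B, T, H = 2, 5, 8
ctx, w = attention_step(np.random.randn(B, T), np.random.randn(B, T, H), np.array([5, 3]))
print(ctx.shape, w.shape)  # (2, 8) (2, 5)
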
Example No. 7
def multi_encoder(encoder_inputs, encoders, encoder_input_length, dropout=None, **kwargs):
    """
    Build multiple encoders according to the configuration in `encoders`, reading from `encoder_inputs`.
    The result is a list of the outputs produced by those encoders (for each time-step), and their final state.

    :param encoder_inputs: list of tensors of shape (batch_size, input_length) (one tensor for each encoder)
    :param encoders: list of encoder configurations
    :param encoder_input_length: list of tensors of shape (batch_size) (one tensor for each encoder)
    :param dropout: scalar tensor or None, specifying the keep probability (1 - dropout)
    :return:
      encoder outputs: a list of tensors of shape (batch_size, input_length, encoder_cell_size)
      encoder state: concatenation of the final states of all encoders, tensor of shape (batch_size, sum_of_state_sizes)
    """
    assert len(encoder_inputs) == len(encoders)
    encoder_states = []
    encoder_outputs = []

    # create embeddings in the global scope (allows sharing between encoder and decoder)
    embedding_variables = []
    for encoder in encoders:
        # inputs are token ids, which need to be mapped to vectors (embeddings)
        if not encoder.binary:
            if encoder.get('embedding') is not None:
                initializer = encoder.embedding
                embedding_shape = None
            else:
                # initializer = tf.random_uniform_initializer(-math.sqrt(3), math.sqrt(3))
                initializer = None
                embedding_shape = [encoder.vocab_size, encoder.embedding_size]

            with tf.device('/cpu:0'):
                embedding = get_variable_unsafe('embedding_{}'.format(encoder.name), shape=embedding_shape,
                                                initializer=initializer)
            embedding_variables.append(embedding)
        else:  # do nothing: inputs are already vectors
            embedding_variables.append(None)

    for i, encoder in enumerate(encoders):
        with tf.variable_scope('encoder_{}'.format(encoder.name)):
            encoder_inputs_ = encoder_inputs[i]
            encoder_input_length_ = encoder_input_length[i]

            # TODO: use state_is_tuple=True
            if encoder.use_lstm:
                cell = BasicLSTMCell(encoder.cell_size, state_is_tuple=False)
            else:
                cell = GRUCell(encoder.cell_size, initializer=orthogonal_initializer())

            if dropout is not None:
                cell = DropoutWrapper(cell, input_keep_prob=dropout)

            embedding = embedding_variables[i]

            if embedding is not None or encoder.input_layers:
                batch_size = tf.shape(encoder_inputs_)[0]  # TODO: fix this time major stuff
                time_steps = tf.shape(encoder_inputs_)[1]

                if embedding is None:
                    size = encoder_inputs_.get_shape()[2].value
                    flat_inputs = tf.reshape(encoder_inputs_, [tf.multiply(batch_size, time_steps), size])
                else:
                    flat_inputs = tf.reshape(encoder_inputs_, [tf.multiply(batch_size, time_steps)])
                    flat_inputs = tf.nn.embedding_lookup(embedding, flat_inputs)

                if encoder.input_layers:
                    for j, size in enumerate(encoder.input_layers):
                        name = 'input_layer_{}'.format(j)
                        flat_inputs = tf.nn.tanh(linear_unsafe(flat_inputs, size, bias=True, scope=name))
                        if dropout is not None:
                            flat_inputs = tf.nn.dropout(flat_inputs, dropout)

                encoder_inputs_ = tf.reshape(flat_inputs,
                                             tf.stack([batch_size, time_steps, flat_inputs.get_shape()[1].value]))

            # Contrary to Theano's RNN implementation, states after the sequence length are zero
            # (while Theano repeats last state)
            sequence_length = encoder_input_length_   # TODO
            parameters = dict(
                inputs=encoder_inputs_, sequence_length=sequence_length, time_pooling=encoder.time_pooling,
                pooling_avg=encoder.pooling_avg, dtype=tf.float32, swap_memory=encoder.swap_memory,
                parallel_iterations=encoder.parallel_iterations, residual_connections=encoder.residual_connections,
                trainable_initial_state=True
            )

            if encoder.bidir:
                encoder_outputs_, _, _ = multi_bidirectional_rnn_unsafe(
                    cells=[(cell, cell)] * encoder.layers, **parameters)
                # Like Bahdanau et al., we use the first annotation h_1 of the backward encoder
                encoder_state_ = encoder_outputs_[:, 0, encoder.cell_size:]
                # TODO: if multiple layers, combine last states with a Maxout layer
            else:
                encoder_outputs_, encoder_state_ = multi_rnn_unsafe(
                    cells=[cell] * encoder.layers, **parameters)
                encoder_state_ = encoder_outputs_[:, -1, :]

            encoder_outputs.append(encoder_outputs_)
            encoder_states.append(encoder_state_)

    encoder_state = tf.concat(encoder_states, 1)
    return encoder_outputs, encoder_state
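
In the bidirectional branch the outputs concatenate forward and backward states along the last axis, so the encoder "state" taken here is h_1 of the backward RNN, i.e. the second half of the output at the first time step (as in Bahdanau et al.). A short NumPy sketch of that slicing, assuming outputs of shape (batch_size, time_steps, 2 * cell_size):

import numpy as np

cell_size = 4
outputs = np.random.randn(3, 7, 2 * cell_size)   # (batch, time, forward + backward)
backward_h1 = outputs[:, 0, cell_size:]          # first annotation of the backward RNN
print(backward_h1.shape)                         # (3, 4)
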
Example No. 8
def multi_encoder(encoder_inputs, encoders, encoder_input_length, dropout=None, **kwargs):
    """
    Build multiple encoders according to the configuration in `encoders`, reading from `encoder_inputs`.
    The result is a list of the outputs produced by those encoders (for each time-step), and their final state.

    :param encoder_inputs: list of tensors of shape (batch_size, input_length) (one tensor for each encoder)
    :param encoders: list of encoder configurations
    :param encoder_input_length: list of tensors of shape (batch_size) (one tensor for each encoder)
    :param dropout: scalar tensor or None, specifying the keep probability (1 - dropout)
    :return:
      encoder outputs: a list of tensors of shape (batch_size, input_length, encoder_cell_size)
      encoder state: concatenation of the final states of all encoders, tensor of shape (batch_size, sum_of_state_sizes)
    """
    assert len(encoder_inputs) == len(encoders)
    encoder_states = []
    encoder_outputs = []

    # create embeddings in the global scope (allows sharing between encoder and decoder)
    embedding_variables = []
    for encoder in encoders:
        # inputs are token ids, which need to be mapped to vectors (embeddings)
        if not encoder.binary:
            if encoder.get('embedding') is not None:
                initializer = encoder.embedding
                embedding_shape = None
            else:
                initializer = tf.random_uniform_initializer(-math.sqrt(3), math.sqrt(3))
                embedding_shape = [encoder.vocab_size, encoder.embedding_size]

            with tf.device('/cpu:0'):
                embedding = get_variable_unsafe('embedding_{}'.format(encoder.name), shape=embedding_shape,
                                                initializer=initializer)
            embedding_variables.append(embedding)
        else:  # do nothing: inputs are already vectors
            embedding_variables.append(None)

    with tf.variable_scope('multi_encoder'):
        for i, encoder in enumerate(encoders):
            with tf.variable_scope(encoder.name):
                encoder_inputs_ = encoder_inputs[i]
                encoder_input_length_ = encoder_input_length[i]

                # TODO: use state_is_tuple=True
                if encoder.use_lstm:
                    cell = rnn_cell.BasicLSTMCell(encoder.cell_size, state_is_tuple=False)
                else:
                    cell = rnn_cell.GRUCell(encoder.cell_size)

                if dropout is not None:
                    cell = rnn_cell.DropoutWrapper(cell, input_keep_prob=dropout)

                embedding = embedding_variables[i]

                if embedding is not None or encoder.input_layers:
                    batch_size = tf.shape(encoder_inputs_)[0]  # TODO: fix this time major stuff
                    time_steps = tf.shape(encoder_inputs_)[1]

                    if embedding is None:
                        size = encoder_inputs_.get_shape()[2].value
                        flat_inputs = tf.reshape(encoder_inputs_, [tf.mul(batch_size, time_steps), size])
                    else:
                        flat_inputs = tf.reshape(encoder_inputs_, [tf.mul(batch_size, time_steps)])
                        flat_inputs = tf.nn.embedding_lookup(embedding, flat_inputs)

                    if encoder.input_layers:
                        for j, size in enumerate(encoder.input_layers):
                            name = 'input_layer_{}'.format(j)
                            flat_inputs = tf.nn.tanh(linear_unsafe(flat_inputs, size, bias=True, scope=name))
                            if dropout is not None:
                                flat_inputs = tf.nn.dropout(flat_inputs, dropout)

                    encoder_inputs_ = tf.reshape(flat_inputs,
                                                 tf.pack([batch_size, time_steps, flat_inputs.get_shape()[1].value]))

                sequence_length = encoder_input_length_
                parameters = dict(
                    inputs=encoder_inputs_, sequence_length=sequence_length, time_pooling=encoder.time_pooling,
                    pooling_avg=encoder.pooling_avg, dtype=tf.float32, swap_memory=encoder.swap_memory,
                    parallel_iterations=encoder.parallel_iterations, residual_connections=encoder.residual_connections
                )

                if encoder.bidir:
                    encoder_outputs_, _, encoder_state_ = multi_bidirectional_rnn_unsafe(
                        cells=[(cell, cell)] * encoder.layers, **parameters)
                else:
                    encoder_outputs_, encoder_state_ = multi_rnn_unsafe(
                        cells=[cell] * encoder.layers, **parameters)

                if encoder.bidir:  # map to correct output dimension
                    # there is no tensor product operation, so we need to flatten our tensor to
                    # a matrix to perform a dot product
                    shape = tf.shape(encoder_outputs_)
                    batch_size = shape[0]
                    time_steps = shape[1]
                    dim = encoder_outputs_.get_shape()[2]
                    outputs_ = tf.reshape(encoder_outputs_, tf.pack([tf.mul(batch_size, time_steps), dim]))
                    outputs_ = linear_unsafe(outputs_, cell.output_size, False, scope='bidir_projection')
                    encoder_outputs_ = tf.reshape(outputs_, tf.pack([batch_size, time_steps, cell.output_size]))

                encoder_outputs.append(encoder_outputs_)
                encoder_states.append(encoder_state_)

        encoder_state = tf.concat(1, encoder_states)
        return encoder_outputs, encoder_state
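
The bidirectional projection above uses the flatten → matmul → reshape pattern mentioned in the comment: since there is no batched tensor-matrix product in this API, the (batch, time, 2H) outputs are reshaped to a matrix, multiplied by a (2H, H) projection, and reshaped back. The same trick in NumPy, with illustrative names:

import numpy as np

B, T, H = 2, 5, 6
outputs = np.random.randn(B, T, 2 * H)        # concatenated forward/backward outputs
W_proj = np.random.randn(2 * H, H)

flat = outputs.reshape(B * T, 2 * H)          # flatten batch and time
projected = (flat @ W_proj).reshape(B, T, H)  # project and restore the shape
print(projected.shape)                        # (2, 5, 6)
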
Example No. 9
def beam_search_decoder(decoder_input, initial_state, attention_states, encoders, decoder, output_projection=None,
                        dropout=None, **kwargs):
    """
    Same as `attention_decoder`, except that it only performs one step of the decoder.

    :param decoder_input: tensor of size (batch_size), corresponding to the previous output of the decoder
    :return:
      current output of the decoder
      tuple of (state, new_state, attn_weights, new_attn_weights, attns, new_attns)
    """
    # TODO: code refactoring with `attention_decoder`
    if decoder.get('embedding') is not None:
        embedding_initializer = decoder.embedding
        embedding_shape = None
    else:
        embedding_initializer = None
        embedding_shape = [decoder.vocab_size, decoder.embedding_size]

    with tf.device('/cpu:0'):
        embedding = get_variable_unsafe('embedding_{}'.format(decoder.name),
                                        shape=embedding_shape,
                                        initializer=embedding_initializer)
    if decoder.use_lstm:
        cell = rnn_cell.BasicLSTMCell(decoder.cell_size, state_is_tuple=False)
    else:
        cell = rnn_cell.GRUCell(decoder.cell_size)

    if dropout is not None:
        cell = rnn_cell.DropoutWrapper(cell, input_keep_prob=dropout)

    if decoder.layers > 1:
        cell = MultiRNNCell([cell] * decoder.layers, residual_connections=decoder.residual_connections)

    if output_projection is None:
        output_size = decoder.vocab_size
    else:
        output_size = cell.output_size
        proj_weights = tf.convert_to_tensor(output_projection[0], dtype=tf.float32)
        proj_weights.get_shape().assert_is_compatible_with([cell.output_size, decoder.vocab_size])
        proj_biases = tf.convert_to_tensor(output_projection[1], dtype=tf.float32)
        proj_biases.get_shape().assert_is_compatible_with([decoder.vocab_size])

    with tf.variable_scope('decoder_{}'.format(decoder.name)):
        decoder_input = tf.nn.embedding_lookup(embedding, decoder_input)

        attn_lengths = [tf.shape(states)[1] for states in attention_states]
        attn_size = sum(states.get_shape()[2].value for states in attention_states)
        hidden_states = [tf.expand_dims(states, 2) for states in attention_states]
        attention_ = functools.partial(multi_attention, hidden_states=hidden_states, encoders=encoders)

        if dropout is not None:
            initial_state = tf.nn.dropout(initial_state, dropout)
        state = tf.nn.tanh(
            linear_unsafe(initial_state, cell.state_size, False, scope='initial_state_projection')
        )

        batch_size = tf.shape(decoder_input)[0]
        attn_weights = [tf.zeros(tf.pack([batch_size, length])) for length in attn_lengths]

        attns = tf.zeros(tf.pack([batch_size, attn_size]), dtype=tf.float32)

        cell_output, new_state = unsafe_decorator(cell)(decoder_input, state)
        new_attns, new_attn_weights = attention_(new_state, prev_weights=attn_weights)

        with tf.variable_scope('attention_output_projection'):
            output = linear_unsafe([cell_output, new_attns], output_size, True)

        beam_tensors = namedtuple('beam_tensors', 'state new_state attn_weights new_attn_weights attns new_attns')
        return output, beam_tensors(state, new_state, attn_weights, new_attn_weights, attns, new_attns)
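
`beam_search_decoder` only runs one step of the decoder; the surrounding beam search (not shown in these examples) would repeatedly call such a step function and keep the k highest-scoring partial hypotheses. A deliberately simplified, generic sketch of that outer loop, operating on toy state-independent per-step log-probabilities just to stay runnable:

import numpy as np

def beam_search(step_log_probs, beam_size):
    """step_log_probs: (steps, vocab) toy log-probs. Returns (tokens, score)."""
    beams = [([], 0.0)]                              # (token list, cumulative log-prob)
    for log_probs in step_log_probs:
        candidates = [(tokens + [tok], score + lp)
                      for tokens, score in beams
                      for tok, lp in enumerate(log_probs)]
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = candidates[:beam_size]               # keep the k best hypotheses
    return beams[0]

logits = np.random.randn(3, 5)                       # 3 steps, vocabulary of 5
log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
print(beam_search(log_probs, beam_size=2))
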
Example No. 10
def attention_decoder(decoder_inputs, initial_state, attention_states, encoders, decoder, decoder_input_length=None,
                      output_projection=None, dropout=None, feed_previous=0.0, **kwargs):
    """
    :param decoder_inputs: tensor of shape (batch_size, output_length)
    :param initial_state: initial state of the decoder (usually the final state of the encoder),
      as a tensor of shape (batch_size, initial_state_size). This state is mapped to the
      correct state size for the decoder.
    :param attention_states: list of tensors of shape (batch_size, input_length, encoder_cell_size),
      usually the encoder outputs (one tensor for each encoder).
    :param encoders: configuration of the encoders
    :param decoder: configuration of the decoder
    :param decoder_input_length: tensor of shape (batch_size) with the length of each target sequence, or None
    :param output_projection: None if no softmax sampling, or tuple (weight matrix, bias vector)
    :param dropout: scalar tensor or None, specifying the keep probability (1 - dropout)
    :param feed_previous: scalar tensor giving the probability of feeding the previous decoder output
      instead of the ground truth as input to the decoder (1 when decoding, between 0 and 1 when training)
    :return:
      outputs of the decoder as a tensor of shape (batch_size, output_length, decoder_cell_size)
      attention weights as a tensor of shape (output_length, encoders, batch_size, input_length)
    """
    # TODO: dropout instead of keep probability
    if decoder.get('embedding') is not None:
        embedding_initializer = decoder.embedding
        embedding_shape = None
    else:
        embedding_initializer = None
        embedding_shape = [decoder.vocab_size, decoder.embedding_size]

    with tf.device('/cpu:0'):
        embedding = get_variable_unsafe('embedding_{}'.format(decoder.name), shape=embedding_shape,
                                        initializer=embedding_initializer)

    if decoder.use_lstm:
        cell = rnn_cell.BasicLSTMCell(decoder.cell_size, state_is_tuple=False)
    else:
        cell = rnn_cell.GRUCell(decoder.cell_size)

    if dropout is not None:
        cell = rnn_cell.DropoutWrapper(cell, input_keep_prob=dropout)

    if decoder.layers > 1:
        cell = MultiRNNCell([cell] * decoder.layers, residual_connections=decoder.residual_connections)

    if output_projection is None:
        output_size = decoder.vocab_size
    else:
        output_size = cell.output_size
        proj_weights = tf.convert_to_tensor(output_projection[0], dtype=tf.float32)
        proj_weights.get_shape().assert_is_compatible_with([cell.output_size, decoder.vocab_size])
        proj_biases = tf.convert_to_tensor(output_projection[1], dtype=tf.float32)
        proj_biases.get_shape().assert_is_compatible_with([decoder.vocab_size])

    with tf.variable_scope('decoder_{}'.format(decoder.name)):
        def extract_argmax_and_embed(prev):
            if output_projection is not None:
                prev = tf.nn.xw_plus_b(prev, output_projection[0], output_projection[1])
            prev_symbol = tf.stop_gradient(tf.argmax(prev, 1))
            emb_prev = tf.nn.embedding_lookup(embedding, prev_symbol)
            return emb_prev

        if embedding is not None:
            time_steps = tf.shape(decoder_inputs)[0]
            batch_size = tf.shape(decoder_inputs)[1]
            flat_inputs = tf.reshape(decoder_inputs, [tf.mul(batch_size, time_steps)])
            flat_inputs = tf.nn.embedding_lookup(embedding, flat_inputs)
            decoder_inputs = tf.reshape(flat_inputs,
                                        tf.pack([time_steps, batch_size, flat_inputs.get_shape()[1].value]))

        attn_lengths = [tf.shape(states)[1] for states in attention_states]
        attn_size = sum(states.get_shape()[2].value for states in attention_states)

        hidden_states = [tf.expand_dims(states, 2) for states in attention_states]
        attention_ = functools.partial(multi_attention, hidden_states=hidden_states, encoders=encoders)

        if dropout is not None:
            initial_state = tf.nn.dropout(initial_state, dropout)

        state = tf.nn.tanh(
            linear_unsafe(initial_state, cell.state_size, False, scope='initial_state_projection')
        )

        sequence_length = decoder_input_length
        if sequence_length is not None:
            sequence_length = tf.to_int32(sequence_length)
            min_sequence_length = tf.reduce_min(sequence_length)
            max_sequence_length = tf.reduce_max(sequence_length)

        time = tf.constant(0, dtype=tf.int32, name='time')

        input_shape = tf.shape(decoder_inputs)
        time_steps = input_shape[0]
        batch_size = input_shape[1]
        state_size = cell.state_size

        zero_output = tf.zeros(tf.pack([batch_size, cell.output_size]), tf.float32)

        output_ta = tf.TensorArray(dtype=tf.float32, size=time_steps, clear_after_read=False)
        input_ta = tf.TensorArray(dtype=tf.float32, size=time_steps).unpack(decoder_inputs)
        attn_weights_ta = tf.TensorArray(dtype=tf.float32, size=time_steps)
        attention_weights = [tf.zeros(tf.pack([batch_size, length])) for length in attn_lengths]

        attns = tf.zeros(tf.pack([batch_size, attn_size]), dtype=tf.float32)
        attns.set_shape([None, attn_size])

        def _time_step(time, state, _, attn_weights, output_ta_t, attn_weights_ta_t):
            input_t = input_ta.read(time)
            # restore some shape information
            r = tf.random_uniform([])
            input_t = tf.cond(tf.logical_and(time > 0, r < feed_previous),
                              lambda: tf.stop_gradient(extract_argmax_and_embed(output_ta_t.read(time - 1))),
                              lambda: input_t)
            input_t.set_shape(decoder_inputs.get_shape()[1:])
            # the code from TensorFlow used a concatenation of input_t and attns as input here
            # TODO: evaluate the impact of this
            call_cell = lambda: unsafe_decorator(cell)(input_t, state)

            if sequence_length is not None:
                output, new_state = rnn._rnn_step(
                    time=time,
                    sequence_length=sequence_length,
                    min_sequence_length=min_sequence_length,
                    max_sequence_length=max_sequence_length,
                    zero_output=zero_output,
                    state=state,
                    call_cell=call_cell,
                    state_size=state_size,
                    skip_conditionals=True)
            else:
                output, new_state = call_cell()

            attn_weights_ta_t = attn_weights_ta_t.write(time, attn_weights)
            # using decoder state instead of decoder output in the attention model seems
            # to give much better results
            new_attns, new_attn_weights = attention_(new_state, prev_weights=attn_weights)

            with tf.variable_scope('attention_output_projection'):  # this can take a lot of memory
                output = linear_unsafe([output, new_attns], output_size, True)

            output_ta_t = output_ta_t.write(time, output)
            return time + 1, new_state, new_attns, new_attn_weights, output_ta_t, attn_weights_ta_t

        _, _, _, _, output_final_ta, attn_weights_final = tf.while_loop(
            cond=lambda time, *_: time < time_steps,
            body=_time_step,
            loop_vars=(time, state, attns, attention_weights, output_ta, attn_weights_ta),
            parallel_iterations=decoder.parallel_iterations,
            swap_memory=decoder.swap_memory)

        outputs = output_final_ta.pack()

        # shape (time_steps, encoders, batch_size, input_time_steps)
        attention_weights = tf.slice(attn_weights_final.pack(), [1, 0, 0, 0], [-1, -1, -1, -1])
        return outputs, attention_weights
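
`extract_argmax_and_embed` first maps the raw decoder output to vocabulary logits through the optional output projection, then embeds the argmax token so it can be fed back as the next input. A NumPy sketch of the same chain, with illustrative names:

import numpy as np

def extract_argmax_and_embed_np(prev, proj_w, proj_b, embedding):
    """prev: (B, H) decoder outputs -> (B, E) embeddings of the argmax tokens."""
    logits = prev @ proj_w + proj_b      # like tf.nn.xw_plus_b, (B, vocab)
    prev_symbol = logits.argmax(axis=1)  # (B,)
    return embedding[prev_symbol]        # embedding lookup, (B, E)

B, H, V, E = 2, 6, 10, 4
emb_prev = extract_argmax_and_embed_np(np.random.randn(B, H), np.random.randn(H, V),
                                       np.random.randn(V), np.random.randn(V, E))
print(emb_prev.shape)  # (2, 4)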