Example #1
    def get_cell(input_size=None, reuse=False):
        cells = []

        for j in range(decoder.layers):
            input_size_ = input_size if j == 0 else decoder.cell_size

            if decoder.cell_type.lower() == 'lstm':
                cell = CellWrapper(BasicLSTMCell(decoder.cell_size, reuse=reuse))
            elif decoder.cell_type.lower() == 'dropoutgru':
                cell = DropoutGRUCell(decoder.cell_size, reuse=reuse, layer_norm=decoder.layer_norm,
                                      input_size=input_size_, input_keep_prob=decoder.rnn_input_keep_prob,
                                      state_keep_prob=decoder.rnn_state_keep_prob)
            else:
                cell = GRUCell(decoder.cell_size, reuse=reuse, layer_norm=decoder.layer_norm)

            if decoder.use_dropout and decoder.cell_type.lower() != 'dropoutgru':
                cell = DropoutWrapper(cell, input_keep_prob=decoder.rnn_input_keep_prob,
                                      output_keep_prob=decoder.rnn_output_keep_prob,
                                      state_keep_prob=decoder.rnn_state_keep_prob,
                                      variational_recurrent=decoder.pervasive_dropout,
                                      dtype=tf.float32, input_size=input_size_)
            cells.append(cell)

        if len(cells) == 1:
            return cells[0]
        else:
            return CellWrapper(MultiRNNCell(cells))
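
For comparison, here is a minimal sketch of the same stacking pattern using only stock TensorFlow 1.x cells; `CellWrapper`, `DropoutGRUCell`, layer normalization and the per-probability dropout options from the repository are omitted, and `build_stacked_cell`, `cell_size`, `layers` and `keep_prob` are illustrative names, not part of the original code.

import tensorflow as tf

def build_stacked_cell(cell_size, layers, keep_prob=1.0):
    # One GRU cell per layer, each wrapped with dropout, as in get_cell above.
    cells = []
    for _ in range(layers):
        cell = tf.nn.rnn_cell.GRUCell(cell_size)
        if keep_prob < 1.0:
            cell = tf.nn.rnn_cell.DropoutWrapper(cell, input_keep_prob=keep_prob,
                                                 output_keep_prob=keep_prob)
        cells.append(cell)
    # A single layer is returned as-is; multiple layers are stacked with MultiRNNCell.
    return cells[0] if len(cells) == 1 else tf.nn.rnn_cell.MultiRNNCell(cells)
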
Example #2
            def cell():
                if encoder.use_lstm:
                    cell = BasicLSTMCell(encoder.cell_size,
                                         state_is_tuple=False)
                else:
                    cell = GRUCell(encoder.cell_size,
                                   initializer=orthogonal_initializer())

                if dropout is not None:
                    cell = DropoutWrapper(cell, input_keep_prob=dropout)
                return cell
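
A factory like `cell()` is typically invoked once per direction so that the forward and backward RNNs get distinct cell instances (and therefore distinct variables). A minimal sketch of that usage with stock TensorFlow 1.x APIs; `make_cell` and the sizes below are illustrative, not taken from the repository.

import tensorflow as tf

def make_cell(cell_size, keep_prob=1.0):
    # Same factory pattern as cell() above, with a stock GRU cell.
    cell = tf.nn.rnn_cell.GRUCell(cell_size)
    if keep_prob < 1.0:
        cell = tf.nn.rnn_cell.DropoutWrapper(cell, input_keep_prob=keep_prob)
    return cell

inputs = tf.placeholder(tf.float32, [None, None, 128])   # (batch_size, time_steps, features)
lengths = tf.placeholder(tf.int32, [None])                # true length of each sequence

# A fresh cell per direction, so the two RNNs do not share variables.
outputs, states = tf.nn.bidirectional_dynamic_rnn(
    make_cell(256), make_cell(256), inputs,
    sequence_length=lengths, dtype=tf.float32)
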
Example #3
            def get_cell(input_size=None, reuse=False):
                if encoder.cell_type.lower() == 'lstm':
                    cell = CellWrapper(BasicLSTMCell(encoder.cell_size, reuse=reuse))
                elif encoder.cell_type.lower() == 'dropoutgru':
                    cell = DropoutGRUCell(encoder.cell_size, reuse=reuse, layer_norm=encoder.layer_norm,
                                          input_size=input_size, input_keep_prob=encoder.rnn_input_keep_prob,
                                          state_keep_prob=encoder.rnn_state_keep_prob)
                else:
                    cell = GRUCell(encoder.cell_size, reuse=reuse, layer_norm=encoder.layer_norm)

                if encoder.use_dropout and encoder.cell_type.lower() != 'dropoutgru':
                    cell = DropoutWrapper(cell, input_keep_prob=encoder.rnn_input_keep_prob,
                                          output_keep_prob=encoder.rnn_output_keep_prob,
                                          state_keep_prob=encoder.rnn_state_keep_prob,
                                          variational_recurrent=encoder.pervasive_dropout,
                                          dtype=tf.float32, input_size=input_size)
                return cell
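
One detail worth noting: `input_size` is threaded through `get_cell` because both `DropoutGRUCell` and TensorFlow's DropoutWrapper need to know the depth of their inputs when the dropout mask is shared across time steps; DropoutWrapper requires `input_size` (together with `dtype`) whenever variational_recurrent=True and input_keep_prob < 1. A minimal illustration with a stock cell and placeholder sizes:

import tensorflow as tf

embedding_size = 512   # placeholder value: the depth of the inputs fed to the wrapped cell
cell = tf.nn.rnn_cell.GRUCell(256)
cell = tf.nn.rnn_cell.DropoutWrapper(
    cell,
    input_keep_prob=0.8, state_keep_prob=0.8, output_keep_prob=0.8,
    variational_recurrent=True,    # reuse one dropout mask across all time steps
    input_size=embedding_size,     # required when variational_recurrent=True and input_keep_prob < 1
    dtype=tf.float32)              # also required when variational_recurrent=True
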
Example #4
def attention_decoder(targets, initial_state, attention_states, encoders, decoder, encoder_input_length,
                      decoder_input_length=None, dropout=None, feed_previous=0.0, feed_argmax=True,
                      **kwargs):
    """
    :param targets: tensor of shape (output_length, batch_size)
    :param initial_state: initial state of the decoder (usually the final state of the encoder),
      as a tensor of shape (batch_size, initial_state_size). This state is mapped to the
      correct state size for the decoder.
    :param attention_states: list of tensors of shape (batch_size, input_length, encoder_cell_size),
      usually the encoder outputs (one tensor for each encoder).
    :param encoders: configuration of the encoders
    :param decoder: configuration of the decoder
    :param encoder_input_length: list of tensors of shape (batch_size,) giving the true length of each
      encoder input sequence (one tensor per encoder)
    :param decoder_input_length: tensor of shape (batch_size,) giving the true length of each target
      sequence; time steps past this length produce zero output
    :param dropout: scalar tensor or None, specifying the keep probability (1 - dropout)
    :param feed_previous: scalar tensor giving the probability of using the previous decoder output
      instead of the ground truth as input to the decoder (1 when decoding, between 0 and 1 when training)
    :param feed_argmax: if True, the symbol fed back to the decoder is the argmax of the previous output;
      otherwise it is sampled from the output distribution
    :return:
      outputs of the decoder as a tensor of shape (batch_size, output_length, decoder_cell_size)
      attention weights as a tensor of shape (output_length, encoders, batch_size, input_length)
    """
    # TODO: dropout instead of keep probability
    assert decoder.cell_size % 2 == 0, 'cell size must be a multiple of 2'   # because of maxout

    decoder_inputs = targets[:-1, :]  # targets start with BOS, so all but the last symbol serve as decoder inputs

    if decoder.get('embedding') is not None:
        initializer = decoder.embedding
        embedding_shape = None
    else:
        initializer = None
        embedding_shape = [decoder.vocab_size, decoder.embedding_size]

    with tf.device('/cpu:0'):
        embedding = get_variable_unsafe('embedding_{}'.format(decoder.name), shape=embedding_shape,
                                        initializer=initializer)

    if decoder.use_lstm:
        cell = BasicLSTMCell(decoder.cell_size, state_is_tuple=False)
    else:
        cell = GRUCell(decoder.cell_size, initializer=orthogonal_initializer())

    if dropout is not None:
        cell = DropoutWrapper(cell, input_keep_prob=dropout)

    if decoder.layers > 1:
        cell = MultiRNNCell([cell] * decoder.layers, residual_connections=decoder.residual_connections)

    with tf.variable_scope('decoder_{}'.format(decoder.name)):
        def embed(input_):
            if embedding is not None:
                return tf.nn.embedding_lookup(embedding, input_)
            else:
                return input_

        hidden_states = [tf.expand_dims(states, 2) for states in attention_states]
        attention_ = functools.partial(multi_attention, hidden_states=hidden_states, encoders=encoders,
                                       encoder_input_length=encoder_input_length)

        input_shape = tf.shape(decoder_inputs)
        time_steps = input_shape[0]
        batch_size = input_shape[1]
        output_size = decoder.vocab_size
        state_size = cell.state_size

        if initial_state is not None:
            if dropout is not None:
                initial_state = tf.nn.dropout(initial_state, dropout)

            state = tf.nn.tanh(
                linear_unsafe(initial_state, state_size, True, scope='initial_state_projection')
            )
        else:
            # if there is no initial state, initialize with zeros (this is the case for MIXER)
            state = tf.zeros([batch_size, state_size], dtype=tf.float32)

        sequence_length = decoder_input_length
        if sequence_length is not None:
            sequence_length = tf.to_int32(sequence_length)
            min_sequence_length = tf.reduce_min(sequence_length)
            max_sequence_length = tf.reduce_max(sequence_length)

        time = tf.constant(0, dtype=tf.int32, name='time')
        zero_output = tf.zeros(tf.stack([batch_size, cell.output_size]), tf.float32)

        proj_outputs = tf.TensorArray(dtype=tf.float32, size=time_steps, clear_after_read=False)
        decoder_outputs = tf.TensorArray(dtype=tf.float32, size=time_steps)

        inputs = tf.TensorArray(dtype=tf.int64, size=time_steps, clear_after_read=False).unstack(
                                tf.cast(decoder_inputs, tf.int64))
        samples = tf.TensorArray(dtype=tf.int64, size=time_steps, clear_after_read=False)
        states = tf.TensorArray(dtype=tf.float32, size=time_steps)

        attn_lengths = [tf.shape(states)[1] for states in attention_states]

        weights = tf.TensorArray(dtype=tf.float32, size=time_steps)
        initial_weights = [tf.zeros(tf.stack([batch_size, length])) for length in attn_lengths]

        output = tf.zeros(tf.stack([batch_size, cell.output_size]), dtype=tf.float32)

        initial_input = embed(inputs.read(0))   # first symbol is BOS

        def _time_step(time, input_, state, output, proj_outputs, decoder_outputs, samples, states, weights,
                       prev_weights):
            context_vector, new_weights = attention_(state, prev_weights=prev_weights)
            weights = weights.write(time, new_weights)

            # FIXME use `output` or `state` here?
            output_ = linear_unsafe([state, input_, context_vector], decoder.cell_size, False, scope='maxout')
            output_ = tf.reduce_max(tf.reshape(output_, tf.stack([batch_size, decoder.cell_size // 2, 2])), axis=2)
            output_ = linear_unsafe(output_, decoder.embedding_size, False, scope='softmax0')
            decoder_outputs = decoder_outputs.write(time, output_)
            output_ = linear_unsafe(output_, output_size, True, scope='softmax1')
            proj_outputs = proj_outputs.write(time, output_)

            argmax = lambda: tf.argmax(output_, 1)
            softmax = lambda: tf.squeeze(tf.multinomial(tf.log(tf.nn.softmax(output_)), num_samples=1),
                                         axis=1)
            target = lambda: inputs.read(time + 1)

            sample = tf.case([
                (tf.logical_and(time < time_steps - 1, tf.random_uniform([]) >= feed_previous), target),
                (tf.logical_not(feed_argmax), softmax)],
                default=argmax)   # default case is useful for beam-search

            sample.set_shape([None])
            sample = tf.stop_gradient(sample)

            samples = samples.write(time, sample)
            input_ = embed(sample)

            x = tf.concat([input_, context_vector], 1)
            call_cell = lambda: unsafe_decorator(cell)(x, state)

            if sequence_length is not None:
                new_output, new_state = rnn._rnn_step(
                    time=time,
                    sequence_length=sequence_length,
                    min_sequence_length=min_sequence_length,
                    max_sequence_length=max_sequence_length,
                    zero_output=zero_output,
                    state=state,
                    call_cell=call_cell,
                    state_size=state_size,
                    skip_conditionals=True)
            else:
                new_output, new_state = call_cell()

            states = states.write(time, new_state)

            return (time + 1, input_, new_state, new_output, proj_outputs, decoder_outputs, samples, states, weights,
                    new_weights)

        _, _, new_state, new_output, proj_outputs, decoder_outputs, samples, states, weights, _ = tf.while_loop(
            cond=lambda time, *_: time < time_steps,
            body=_time_step,
            loop_vars=(time, initial_input, state, output, proj_outputs, decoder_outputs, samples, states,
                       weights, initial_weights),
            parallel_iterations=decoder.parallel_iterations,
            swap_memory=decoder.swap_memory)

        proj_outputs = proj_outputs.stack()
        decoder_outputs = decoder_outputs.stack()
        samples = samples.stack()
        weights = weights.stack()  # output time, encoders, batch_size, input time
        states = states.stack()

        # weights = tf.Print(weights, [weights[:,0]], summarize=20)
        # tf.control_dependencies()

        beam_tensors = namedtuple('beam_tensors', 'state new_state output new_output')
        return (proj_outputs, weights, decoder_outputs, beam_tensors(state, new_state, output, new_output),
                samples, states)
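
The decoder above drives the recurrence manually with tf.while_loop and TensorArrays instead of dynamic_rnn, so that each step can feed the (possibly sampled) previous output back in as the next input. Stripped of attention, sampling and output projections, the control-flow skeleton looks roughly as follows; all sizes and the `step` function are placeholders, not code from the repository.

import tensorflow as tf

batch_size, time_steps, cell_size = 32, 20, 256
cell = tf.nn.rnn_cell.GRUCell(cell_size)

inputs = tf.zeros([time_steps, batch_size, cell_size])        # stand-in for embedded decoder inputs
inputs_ta = tf.TensorArray(tf.float32, size=time_steps).unstack(inputs)
outputs_ta = tf.TensorArray(tf.float32, size=time_steps)      # one slot per decoder step

def step(time, state, outputs_ta):
    # In the real decoder the input is the embedding of the previous prediction (or the ground truth).
    output, new_state = cell(inputs_ta.read(time), state)
    outputs_ta = outputs_ta.write(time, output)
    return time + 1, new_state, outputs_ta

_, final_state, outputs_ta = tf.while_loop(
    cond=lambda time, *_: time < time_steps,
    body=step,
    loop_vars=(tf.constant(0), cell.zero_state(batch_size, tf.float32), outputs_ta))

outputs = outputs_ta.stack()                                   # (time_steps, batch_size, cell_size)
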
Example #5
def multi_encoder(encoder_inputs, encoders, encoder_input_length, dropout=None, **kwargs):
    """
    Build multiple encoders according to the configuration in `encoders`, reading from `encoder_inputs`.
    The result is a list of the outputs produced by those encoders (for each time-step), and their final state.

    :param encoder_inputs: list of tensors of shape (batch_size, input_length) (one tensor for each encoder)
    :param encoders: list of encoder configurations
    :param encoder_input_length: list of tensors of shape (batch_size) (one tensor for each encoder)
    :param dropout: scalar tensor or None, specifying the keep probability (1 - dropout)
    :return:
      encoder outputs: a list of tensors of shape (batch_size, input_length, encoder_cell_size)
      encoder state: concatenation of the final states of all encoders, tensor of shape (batch_size, sum_of_state_sizes)
    """
    assert len(encoder_inputs) == len(encoders)
    encoder_states = []
    encoder_outputs = []

    # create embeddings in the global scope (allows sharing between encoder and decoder)
    embedding_variables = []
    for encoder in encoders:
        # inputs are token ids, which need to be mapped to vectors (embeddings)
        if not encoder.binary:
            if encoder.get('embedding') is not None:
                initializer = encoder.embedding
                embedding_shape = None
            else:
                # initializer = tf.random_uniform_initializer(-math.sqrt(3), math.sqrt(3))
                initializer = None
                embedding_shape = [encoder.vocab_size, encoder.embedding_size]

            with tf.device('/cpu:0'):
                embedding = get_variable_unsafe('embedding_{}'.format(encoder.name), shape=embedding_shape,
                                                initializer=initializer)
            embedding_variables.append(embedding)
        else:  # do nothing: inputs are already vectors
            embedding_variables.append(None)

    for i, encoder in enumerate(encoders):
        with tf.variable_scope('encoder_{}'.format(encoder.name)):
            encoder_inputs_ = encoder_inputs[i]
            encoder_input_length_ = encoder_input_length[i]

            # TODO: use state_is_tuple=True
            if encoder.use_lstm:
                cell = BasicLSTMCell(encoder.cell_size, state_is_tuple=False)
            else:
                cell = GRUCell(encoder.cell_size, initializer=orthogonal_initializer())

            if dropout is not None:
                cell = DropoutWrapper(cell, input_keep_prob=dropout)

            embedding = embedding_variables[i]

            if embedding is not None or encoder.input_layers:
                batch_size = tf.shape(encoder_inputs_)[0]  # TODO: fix this time major stuff
                time_steps = tf.shape(encoder_inputs_)[1]

                if embedding is None:
                    size = encoder_inputs_.get_shape()[2].value
                    flat_inputs = tf.reshape(encoder_inputs_, [tf.multiply(batch_size, time_steps), size])
                else:
                    flat_inputs = tf.reshape(encoder_inputs_, [tf.multiply(batch_size, time_steps)])
                    flat_inputs = tf.nn.embedding_lookup(embedding, flat_inputs)

                if encoder.input_layers:
                    for j, size in enumerate(encoder.input_layers):
                        name = 'input_layer_{}'.format(j)
                        flat_inputs = tf.nn.tanh(linear_unsafe(flat_inputs, size, bias=True, scope=name))
                        if dropout is not None:
                            flat_inputs = tf.nn.dropout(flat_inputs, dropout)

                encoder_inputs_ = tf.reshape(flat_inputs,
                                             tf.stack([batch_size, time_steps, flat_inputs.get_shape()[1].value]))

            # Contrary to Theano's RNN implementation, states after the sequence length are zero
            # (whereas Theano repeats the last state)
            sequence_length = encoder_input_length_   # TODO
            parameters = dict(
                inputs=encoder_inputs_, sequence_length=sequence_length, time_pooling=encoder.time_pooling,
                pooling_avg=encoder.pooling_avg, dtype=tf.float32, swap_memory=encoder.swap_memory,
                parallel_iterations=encoder.parallel_iterations, residual_connections=encoder.residual_connections,
                trainable_initial_state=True
            )

            if encoder.bidir:
                encoder_outputs_, _, _ = multi_bidirectional_rnn_unsafe(
                    cells=[(cell, cell)] * encoder.layers, **parameters)
                # Like Bahdanau et al., we use the first annotation h_1 of the backward encoder
                encoder_state_ = encoder_outputs_[:, 0, encoder.cell_size:]
                # TODO: if multiple layers, combine last states with a Maxout layer
            else:
                encoder_outputs_, encoder_state_ = multi_rnn_unsafe(
                    cells=[cell] * encoder.layers, **parameters)
                encoder_state_ = encoder_outputs_[:, -1, :]

            encoder_outputs.append(encoder_outputs_)
            encoder_states.append(encoder_state_)

    encoder_state = tf.concat(encoder_states, 1)
    return encoder_outputs, encoder_state
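
The embedding block in the middle of multi_encoder flattens the (batch_size, time_steps) id matrix, looks the ids up in the embedding table, optionally passes them through dense input layers, and reshapes the result back to (batch_size, time_steps, features). A self-contained sketch of that reshape-lookup-reshape pattern; the vocabulary, sizes and the use of tf.layers.dense as a stand-in for linear_unsafe are placeholders, not code from the repository.

import tensorflow as tf

vocab_size, embedding_size, layer_size = 10000, 128, 64
token_ids = tf.placeholder(tf.int32, [None, None])             # (batch_size, time_steps)
embedding = tf.get_variable('embedding', [vocab_size, embedding_size])

batch_size = tf.shape(token_ids)[0]
time_steps = tf.shape(token_ids)[1]

flat_ids = tf.reshape(token_ids, [-1])                          # flatten so the dense layer sees a matrix
flat_inputs = tf.nn.embedding_lookup(embedding, flat_ids)       # (batch_size * time_steps, embedding_size)
flat_inputs = tf.layers.dense(flat_inputs, layer_size, activation=tf.nn.tanh)  # stand-in input layer

encoder_inputs = tf.reshape(flat_inputs,
                            tf.stack([batch_size, time_steps, layer_size]))    # (batch, time, features)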