Example #1
def LSTMWithAttention(
    model,
    decoder_inputs,
    decoder_input_lengths,
    initial_decoder_hidden_state,
    initial_decoder_cell_state,
    initial_attention_weighted_encoder_context,
    encoder_output_dim,
    encoder_outputs,
    decoder_input_dim,
    decoder_state_dim,
    scope,
    attention_type=AttentionType.Regular,
    outputs_with_grads=(0, 4),
    weighted_encoder_outputs=None,
    lstm_memory_optimization=False,
    attention_memory_optimization=False,
    forget_bias=0.0,
):
    '''
    Adds an LSTM with an attention mechanism to a model.

    The implementation is based on https://arxiv.org/abs/1409.0473, with
    a small difference in the order in which the new attention context and
    the new hidden state are computed, similar to
    https://arxiv.org/abs/1508.04025.

    The model uses encoder-decoder naming conventions: the decoder is the
    sequence the op iterates over, while the attention context is computed
    over the encoder outputs.

    model: CNNModelHelper object new operators would be added to

    decoder_inputs: the input sequence, in the format T x N x D,
    where T is the sequence length, N the batch size and D the input dimension

    decoder_input_lengths: blob containing the sequence lengths,
    which is passed to the LSTMUnit operator

    initial_decoder_hidden_state: initial hidden state of LSTM

    initial_decoder_cell_state: initial cell state of LSTM

    initial_attention_weighted_encoder_context: initial attention context

    encoder_output_dim: dimension of encoder outputs

    encoder_outputs: the sequence, on which we compute the attention context
    at every iteration

    decoder_input_dim: input dimension (the last dimension of decoder_inputs)

    decoder_state_dim: size of hidden states of LSTM

    attention_type: One of: AttentionType.Regular, AttentionType.Recurrent.
    Determines which type of attention mechanism to use.

    outputs_with_grads: position indices of the output blobs which receive
    an external error gradient during backpropagation

    weighted_encoder_outputs: encoder outputs to be used to compute attention
    weights. In the basic case it is just a linear transformation of the
    encoder outputs (the default, when weighted_encoder_outputs is None).
    However, it can be something more complicated, such as a separate
    encoder network (for example, in the case of a convolutional encoder)

    lstm_memory_optimization: recompute LSTM activations on the backward pass,
    so their values don't need to be stored during the forward pass

    attention_memory_optimization: recompute attention for backward pass
    '''
    def s(name):
        # We have to manually scope due to our internal/external blob
        # relationships.
        return "{}/{}".format(str(scope), str(name))

    decoder_inputs = model.FC(
        decoder_inputs,
        s('i2h'),
        dim_in=decoder_input_dim,
        dim_out=4 * decoder_state_dim,
        axis=2,
    )
    # [batch_size, encoder_output_dim, encoder_length]
    encoder_outputs_transposed = model.Transpose(
        encoder_outputs,
        s('encoder_outputs_transposed'),
        axes=[1, 2, 0],
    )
    if weighted_encoder_outputs is None:
        weighted_encoder_outputs = model.FC(
            encoder_outputs,
            s('weighted_encoder_outputs'),
            dim_in=encoder_output_dim,
            dim_out=encoder_output_dim,
            axis=2,
        )
    step_model = CNNModelHelper(
        name='lstm_with_attention_cell',
        param_model=model,
    )
    (
        input_t,
        timestep,
        cell_t_prev,
        hidden_t_prev,
        attention_weighted_encoder_context_t_prev,
    ) = (step_model.net.AddScopedExternalInputs(
        'input_t',
        'timestep',
        'cell_t_prev',
        'hidden_t_prev',
        'attention_weighted_encoder_context_t_prev',
    ))
    step_model.net.AddExternalInputs(encoder_outputs_transposed,
                                     weighted_encoder_outputs)

    gates_concatenated_input_t, _ = step_model.net.Concat(
        [hidden_t_prev, attention_weighted_encoder_context_t_prev],
        [
            s('gates_concatenated_input_t'),
            s('_gates_concatenated_input_t_concat_dims'),
        ],
        axis=2,
    )
    gates_t = step_model.FC(
        gates_concatenated_input_t,
        s('gates_t'),
        dim_in=decoder_state_dim + encoder_output_dim,
        dim_out=4 * decoder_state_dim,
        axis=2,
    )
    step_model.net.Sum([gates_t, input_t], gates_t)

    hidden_t_intermediate, cell_t = step_model.net.LSTMUnit(
        [hidden_t_prev, cell_t_prev, gates_t, decoder_input_lengths, timestep],
        ['hidden_t_intermediate', s('cell_t')],
        forget_bias=forget_bias,
    )
    if attention_type == AttentionType.Recurrent:
        attention_weighted_encoder_context_t, _, attention_blobs = apply_recurrent_attention(
            model=step_model,
            encoder_output_dim=encoder_output_dim,
            encoder_outputs_transposed=encoder_outputs_transposed,
            weighted_encoder_outputs=weighted_encoder_outputs,
            decoder_hidden_state_t=hidden_t_intermediate,
            decoder_hidden_state_dim=decoder_state_dim,
            scope=scope,
            attention_weighted_encoder_context_t_prev=(
                attention_weighted_encoder_context_t_prev),
        )
    else:
        attention_weighted_encoder_context_t, _, attention_blobs = apply_regular_attention(
            model=step_model,
            encoder_output_dim=encoder_output_dim,
            encoder_outputs_transposed=encoder_outputs_transposed,
            weighted_encoder_outputs=weighted_encoder_outputs,
            decoder_hidden_state_t=hidden_t_intermediate,
            decoder_hidden_state_dim=decoder_state_dim,
            scope=scope,
        )
    hidden_t = step_model.Copy(hidden_t_intermediate, s('hidden_t'))
    step_model.net.AddExternalOutputs(
        cell_t,
        hidden_t,
        attention_weighted_encoder_context_t,
    )
    recompute_blobs = []
    if attention_memory_optimization:
        recompute_blobs.extend(attention_blobs)
    if lstm_memory_optimization:
        recompute_blobs.extend([gates_t])

    return recurrent_net(
        net=model.net,
        cell_net=step_model.net,
        inputs=[
            (input_t, decoder_inputs),
        ],
        initial_cell_inputs=[
            (hidden_t_prev, initial_decoder_hidden_state),
            (cell_t_prev, initial_decoder_cell_state),
            (
                attention_weighted_encoder_context_t_prev,
                initial_attention_weighted_encoder_context,
            ),
        ],
        links={
            hidden_t_prev:
            hidden_t,
            cell_t_prev:
            cell_t,
            attention_weighted_encoder_context_t_prev:
            (attention_weighted_encoder_context_t),
        },
        timestep=timestep,
        scope=scope,
        outputs_with_grads=outputs_with_grads,
        recompute_blobs_on_backward=recompute_blobs,
    )
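
How LSTMWithAttention might be wired into a model, as a minimal sketch. The blob names and dimensions below are illustrative assumptions, not taken from the example above; the imports assume Caffe2's CNNModelHelper and AttentionType.

from caffe2.python.cnn import CNNModelHelper
from caffe2.python.attention import AttentionType

# Hypothetical blob names and sizes, for illustration only.
model = CNNModelHelper(name='attention_decoder_example')
decoder_outputs = LSTMWithAttention(
    model=model,
    decoder_inputs='decoder_inputs',            # T x N x decoder_input_dim
    decoder_input_lengths='decoder_lengths',    # one length per batch item
    initial_decoder_hidden_state='initial_hidden_state',
    initial_decoder_cell_state='initial_cell_state',
    initial_attention_weighted_encoder_context='initial_attention_context',
    encoder_output_dim=256,
    encoder_outputs='encoder_outputs',          # T_enc x N x encoder_output_dim
    decoder_input_dim=128,
    decoder_state_dim=256,
    scope='decoder',
    attention_type=AttentionType.Regular,
)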
Example #2
    def _apply(
        self,
        model,
        input_t,
        seq_lengths,
        states,
        timestep,
    ):
        decoder_prev_states = states[:-1]
        attention_weighted_encoder_context_t_prev = states[-1]

        decoder_states = self.decoder_cell._apply(
            model,
            input_t,
            seq_lengths,
            decoder_prev_states,
            timestep,
            extra_inputs=[(
                attention_weighted_encoder_context_t_prev,
                self.encoder_output_dim,
            )],
        )

        hidden_t_intermediate = \
            decoder_states[self.decoder_cell.output_state_index()]

        if self.attention_type == AttentionType.Recurrent:
            (
                attention_weighted_encoder_context_t,
                self.attention_weights_3d,
                attention_blobs,
            ) = apply_recurrent_attention(
                model=model,
                encoder_output_dim=self.encoder_output_dim,
                encoder_outputs_transposed=self.encoder_outputs_transposed,
                weighted_encoder_outputs=self.weighted_encoder_outputs,
                decoder_hidden_state_t=hidden_t_intermediate,
                decoder_hidden_state_dim=self.decoder_state_dim,
                scope=self.name,
                attention_weighted_encoder_context_t_prev=(
                    attention_weighted_encoder_context_t_prev),
            )
        else:
            (
                attention_weighted_encoder_context_t,
                self.attention_weights_3d,
                attention_blobs,
            ) = apply_regular_attention(
                model=model,
                encoder_output_dim=self.encoder_output_dim,
                encoder_outputs_transposed=self.encoder_outputs_transposed,
                weighted_encoder_outputs=self.weighted_encoder_outputs,
                decoder_hidden_state_t=hidden_t_intermediate,
                decoder_hidden_state_dim=self.decoder_state_dim,
                scope=self.name,
            )

        if self.attention_memory_optimization:
            self.recompute_blobs.extend(attention_blobs)

        hidden_t = model.Copy(
            hidden_t_intermediate,
            self.scope('hidden_t_external'),
        )
        output = list(decoder_states) + [attention_weighted_encoder_context_t]
        output[self.decoder_cell.output_state_index()] = hidden_t
        model.net.AddExternalOutputs(*output)

        return output
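
Both branches above return an attention-weighted context computed from the current decoder hidden state. For reference, here is a Bahdanau-style numpy sketch of what the regular-attention branch computes; the weight names are hypothetical and this is not the library's apply_regular_attention code, just the underlying arithmetic.

import numpy as np

# Bahdanau-style sketch of regular attention (illustrative only).
# enc_outputs @ W_enc plays the role of weighted_encoder_outputs, which is
# why it can be precomputed once outside the recurrent step.
def regular_attention_sketch(h_dec, enc_outputs, W_enc, W_dec, v):
    energies = np.tanh(enc_outputs @ W_enc + h_dec @ W_dec) @ v   # [T_enc]
    weights = np.exp(energies - energies.max())
    weights /= weights.sum()                                      # softmax
    context = weights @ enc_outputs                               # [D_enc]
    return context, weights

T_enc, D_enc, D_dec = 5, 6, 8
context, weights = regular_attention_sketch(
    np.random.randn(D_dec),            # decoder hidden state
    np.random.randn(T_enc, D_enc),     # encoder outputs for one sequence
    np.random.randn(D_enc, D_enc),     # encoder projection
    np.random.randn(D_dec, D_enc),     # decoder-state projection
    np.random.randn(D_enc),            # attention vector v
)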
Example #3
    def _apply(
        self,
        model,
        input_t,
        seq_lengths,
        states,
        timestep,
    ):
        decoder_prev_states = states[:-1]
        attention_weighted_encoder_context_t_prev = states[-1]

        decoder_states = self.decoder_cell._apply(
            model,
            input_t,
            seq_lengths,
            decoder_prev_states,
            timestep,
            extra_inputs=[(
                attention_weighted_encoder_context_t_prev,
                self.encoder_output_dim,
            )],
        )
        # TODO: we should use the prepare_output method here, but because of
        # an edge case in recurrent_net we have to know which state is being
        # used to compute attention. So instead of manipulating the output of
        # the cell, we have to work with the output state directly.
        # In other words, if the output of decoder_cell is not equal to one of
        # decoder_cell's states (the one at get_output_state_index()), then
        # this logic is broken. Right now, that can happen if there is
        # dropout, so we explicitly check that dropout has been disabled.
        assert self.decoder_cell.dropout_ratio is None

        hidden_t_intermediate = \
            decoder_states[self.decoder_cell.get_output_state_index()]

        if self.attention_type == AttentionType.Recurrent:
            (
                attention_weighted_encoder_context_t,
                self.attention_weights_3d,
                attention_blobs,
            ) = apply_recurrent_attention(
                model=model,
                encoder_output_dim=self.encoder_output_dim,
                encoder_outputs_transposed=self.encoder_outputs_transposed,
                weighted_encoder_outputs=self.weighted_encoder_outputs,
                decoder_hidden_state_t=hidden_t_intermediate,
                decoder_hidden_state_dim=self.decoder_state_dim,
                scope=self.name,
                attention_weighted_encoder_context_t_prev=(
                    attention_weighted_encoder_context_t_prev),
            )
        else:
            (
                attention_weighted_encoder_context_t,
                self.attention_weights_3d,
                attention_blobs,
            ) = apply_regular_attention(
                model=model,
                encoder_output_dim=self.encoder_output_dim,
                encoder_outputs_transposed=self.encoder_outputs_transposed,
                weighted_encoder_outputs=self.weighted_encoder_outputs,
                decoder_hidden_state_t=hidden_t_intermediate,
                decoder_hidden_state_dim=self.decoder_state_dim,
                scope=self.name,
            )

        if self.attention_memory_optimization:
            self.recompute_blobs.extend(attention_blobs)

        hidden_t = model.Copy(
            hidden_t_intermediate,
            self.scope('hidden_t_external'),
        )
        output = list(decoder_states) + [attention_weighted_encoder_context_t]
        output[self.decoder_cell.get_output_state_index()] = hidden_t
        model.net.AddExternalOutputs(*output)

        return output
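
The recurrent-attention branch differs only in that the previous attention context also feeds into the attention energies. A hedged numpy sketch of that idea follows; this is an assumed formulation for illustration, and the exact apply_recurrent_attention arithmetic may differ in detail.

import numpy as np

# Hedged sketch of recurrent attention: the energy at each encoder position
# also conditions on the previous attention context (assumed formulation).
def recurrent_attention_sketch(h_dec, context_prev, enc_outputs,
                               W_enc, W_dec, W_ctx, v):
    energies = np.tanh(
        enc_outputs @ W_enc + h_dec @ W_dec + context_prev @ W_ctx) @ v
    weights = np.exp(energies - energies.max())
    weights /= weights.sum()
    return weights @ enc_outputs, weights

T_enc, D_enc, D_dec = 5, 6, 8
context, weights = recurrent_attention_sketch(
    np.random.randn(D_dec), np.zeros(D_enc), np.random.randn(T_enc, D_enc),
    np.random.randn(D_enc, D_enc), np.random.randn(D_dec, D_enc),
    np.random.randn(D_enc, D_enc), np.random.randn(D_enc))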
Example #4
    def _apply(
        self,
        model,
        input_t,
        seq_lengths,
        states,
        timestep,
    ):
        (
            hidden_t_prev,
            cell_t_prev,
            attention_weighted_encoder_context_t_prev,
        ) = states

        gates_concatenated_input_t, _ = model.net.Concat(
            [hidden_t_prev, attention_weighted_encoder_context_t_prev],
            [
                self.scope('gates_concatenated_input_t'),
                self.scope('_gates_concatenated_input_t_concat_dims'),
            ],
            axis=2,
        )
        # hU^T
        # Shape: [1, batch_size, 4 * hidden_size]
        prev_t = model.FC(
            gates_concatenated_input_t,
            self.scope('prev_t'),
            dim_in=self.decoder_state_dim + self.encoder_output_dim,
            dim_out=4 * self.decoder_state_dim,
            axis=2,
        )
        # define the multiplicative-integration (MI) parameters
        alpha = model.param_init_net.ConstantFill(
            [], [self.scope('alpha')],
            shape=[4 * self.decoder_state_dim],
            value=1.0)
        beta1 = model.param_init_net.ConstantFill(
            [], [self.scope('beta1')],
            shape=[4 * self.decoder_state_dim],
            value=1.0)
        beta2 = model.param_init_net.ConstantFill(
            [], [self.scope('beta2')],
            shape=[4 * self.decoder_state_dim],
            value=1.0)
        b = model.param_init_net.ConstantFill(
            [], [self.scope('b')],
            shape=[4 * self.decoder_state_dim],
            value=0.0)
        model.params.extend([alpha, beta1, beta2, b])
        # alpha * (xW^T * hU^T)
        # Shape: [1, batch_size, 4 * hidden_size]
        alpha_tdash = model.net.Mul([prev_t, input_t],
                                    self.scope('alpha_tdash'))
        # Shape: [batch_size, 4 * hidden_size]
        alpha_tdash_rs, _ = model.net.Reshape(
            alpha_tdash,
            [
                self.scope('alpha_tdash_rs'),
                self.scope('alpha_tdash_old_shape')
            ],
            shape=[-1, 4 * self.decoder_state_dim],
        )
        alpha_t = model.net.Mul([alpha_tdash_rs, alpha],
                                self.scope('alpha_t'),
                                broadcast=1,
                                use_grad_hack=1)
        # beta1 * hU^T
        # Shape: [batch_size, 4 * hidden_size]
        prev_t_rs, _ = model.net.Reshape(
            prev_t,
            [self.scope('prev_t_rs'),
             self.scope('prev_t_old_shape')],
            shape=[-1, 4 * self.decoder_state_dim],
        )
        beta1_t = model.net.Mul([prev_t_rs, beta1],
                                self.scope('beta1_t'),
                                broadcast=1,
                                use_grad_hack=1)
        # beta2 * xW^T
        # Shape: [batch_size, 4 * hidden_size]
        input_t_rs, _ = model.net.Reshape(
            input_t,
            [self.scope('input_t_rs'),
             self.scope('input_t_old_shape')],
            shape=[-1, 4 * self.decoder_state_dim],
        )
        beta2_t = model.net.Mul([input_t_rs, beta2],
                                self.scope('beta2_t'),
                                broadcast=1,
                                use_grad_hack=1)
        # Add 'em all up
        gates_tdash = model.net.Sum([alpha_t, beta1_t, beta2_t],
                                    self.scope('gates_tdash'))
        gates_t = model.net.Add([gates_tdash, b],
                                self.scope('gates_t'),
                                broadcast=1,
                                use_grad_hack=1)
        # Shape: [1, batch_size, 4 * hidden_size]
        gates_t_rs, _ = model.net.Reshape(
            gates_t,
            [self.scope('gates_t_rs'),
             self.scope('gates_t_old_shape')],
            shape=[1, -1, 4 * self.decoder_state_dim],
        )

        hidden_t_intermediate, cell_t = model.net.LSTMUnit(
            [hidden_t_prev, cell_t_prev, gates_t_rs, seq_lengths, timestep],
            [self.scope('hidden_t_intermediate'),
             self.scope('cell_t')],
        )

        if self.attention_type == AttentionType.Recurrent:
            (
                attention_weighted_encoder_context_t,
                self.attention_weights_3d,
                self.recompute_blobs,
            ) = (apply_recurrent_attention(
                model=model,
                encoder_output_dim=self.encoder_output_dim,
                encoder_outputs_transposed=self.encoder_outputs_transposed,
                weighted_encoder_outputs=self.weighted_encoder_outputs,
                decoder_hidden_state_t=hidden_t_intermediate,
                decoder_hidden_state_dim=self.decoder_state_dim,
                scope=self.name,
                attention_weighted_encoder_context_t_prev=(
                    attention_weighted_encoder_context_t_prev),
            ))
        else:
            (
                attention_weighted_encoder_context_t,
                self.attention_weights_3d,
                self.recompute_blobs,
            ) = (apply_regular_attention(
                model=model,
                encoder_output_dim=self.encoder_output_dim,
                encoder_outputs_transposed=self.encoder_outputs_transposed,
                weighted_encoder_outputs=self.weighted_encoder_outputs,
                decoder_hidden_state_t=hidden_t_intermediate,
                decoder_hidden_state_dim=self.decoder_state_dim,
                scope=self.name,
            ))
        hidden_t = model.Copy(hidden_t_intermediate, self.scope('hidden_t'))
        model.net.AddExternalOutputs(
            cell_t,
            hidden_t,
            attention_weighted_encoder_context_t,
        )
        return hidden_t, cell_t, attention_weighted_encoder_context_t
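
The chain of Mul, Sum, and Add operators above implements a multiplicative-integration (MI) gate pre-activation. The same arithmetic in a few lines of numpy (illustrative shapes and values; not the Caffe2 graph itself):

import numpy as np

# gates = alpha * (xW * hU) + beta1 * hU + beta2 * xW + b, elementwise.
batch_size, decoder_state_dim = 2, 8
G = 4 * decoder_state_dim                 # i, f, o, g gates stacked together

xW = np.random.randn(batch_size, G)       # input_t: precomputed xW^T
hU = np.random.randn(batch_size, G)       # prev_t: projected [h_prev, context_prev]
alpha = np.ones(G)                        # MI parameters, initialized as above
beta1 = np.ones(G)
beta2 = np.ones(G)
b = np.zeros(G)

gates = alpha * (xW * hU) + beta1 * hU + beta2 * xW + b
gates = gates.reshape(1, batch_size, G)   # back to [1, batch, 4 * hidden] for LSTMUnit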
Example #5
    def _apply(
        self,
        model,
        input_t,
        seq_lengths,
        states,
        timestep,
    ):
        (
            hidden_t_prev,
            cell_t_prev,
            attention_weighted_encoder_context_t_prev,
        ) = states

        gates_concatenated_input_t, _ = model.net.Concat(
            [hidden_t_prev, attention_weighted_encoder_context_t_prev],
            [
                self.scope('gates_concatenated_input_t'),
                self.scope('_gates_concatenated_input_t_concat_dims'),
            ],
            axis=2,
        )
        gates_t = model.FC(
            gates_concatenated_input_t,
            self.scope('gates_t'),
            dim_in=self.decoder_state_dim + self.encoder_output_dim,
            dim_out=4 * self.decoder_state_dim,
            axis=2,
        )
        model.net.Sum([gates_t, input_t], gates_t)

        hidden_t_intermediate, cell_t = model.net.LSTMUnit(
            [
                hidden_t_prev,
                cell_t_prev,
                gates_t,
                seq_lengths,
                timestep,
            ],
            ['hidden_t_intermediate',
             self.scope('cell_t')],
        )
        if self.attention_type == AttentionType.Recurrent:
            (
                attention_weighted_encoder_context_t,
                self.attention_weights_3d,
                attention_blobs,
            ) = apply_recurrent_attention(
                model=model,
                encoder_output_dim=self.encoder_output_dim,
                encoder_outputs_transposed=self.encoder_outputs_transposed,
                weighted_encoder_outputs=self.weighted_encoder_outputs,
                decoder_hidden_state_t=hidden_t_intermediate,
                decoder_hidden_state_dim=self.decoder_state_dim,
                scope=self.name,
                attention_weighted_encoder_context_t_prev=(
                    attention_weighted_encoder_context_t_prev),
            )
        else:
            (
                attention_weighted_encoder_context_t,
                self.attention_weights_3d,
                attention_blobs,
            ) = apply_regular_attention(
                model=model,
                encoder_output_dim=self.encoder_output_dim,
                encoder_outputs_transposed=self.encoder_outputs_transposed,
                weighted_encoder_outputs=self.weighted_encoder_outputs,
                decoder_hidden_state_t=hidden_t_intermediate,
                decoder_hidden_state_dim=self.decoder_state_dim,
                scope=self.name,
            )
        hidden_t = model.Copy(hidden_t_intermediate, self.scope('hidden_t'))
        model.net.AddExternalOutputs(
            cell_t,
            hidden_t,
            attention_weighted_encoder_context_t,
        )
        if self.attention_memory_optimization:
            self.recompute_blobs.extend(attention_blobs)
        if self.lstm_memory_optimization:
            self.recompute_blobs.append(gates_t)

        return hidden_t, cell_t, attention_weighted_encoder_context_t
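
Compared with the MI cell in Example #4, this cell's gate pre-activation is purely additive: the i2h projection of the input (input_t) is summed with an FC over the concatenated previous hidden state and attention context. In numpy terms (illustrative shapes):

import numpy as np

# gates = [h_prev, context_prev] @ U + b + input_t, where input_t = i2h(x_t).
batch_size, decoder_state_dim, encoder_output_dim = 2, 8, 6
G = 4 * decoder_state_dim

U = np.random.randn(decoder_state_dim + encoder_output_dim, G)  # FC weight
b = np.zeros(G)                                                 # FC bias
h_prev = np.random.randn(batch_size, decoder_state_dim)
context_prev = np.random.randn(batch_size, encoder_output_dim)
input_t = np.random.randn(batch_size, G)    # i2h(x_t), computed outside the step

gates = np.concatenate([h_prev, context_prev], axis=1) @ U + b + input_t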
Example #6
    def _apply(
        self,
        model,
        input_t,
        seq_lengths,
        states,
        timestep,
        extra_inputs=None,
    ):
        decoder_prev_states = states[:-1]
        attention_weighted_encoder_context_t_prev = states[-1]

        assert extra_inputs is None

        decoder_states = self.decoder_cell._apply(
            model,
            input_t,
            seq_lengths,
            decoder_prev_states,
            timestep,
            extra_inputs=[(
                attention_weighted_encoder_context_t_prev,
                self.encoder_output_dim,
            )],
        )

        self.hidden_t_intermediate = self.decoder_cell._prepare_output(
            model,
            decoder_states,
        )

        if self.attention_type == AttentionType.Recurrent:
            (
                attention_weighted_encoder_context_t,
                self.attention_weights_3d,
                attention_blobs,
            ) = apply_recurrent_attention(
                model=model,
                encoder_output_dim=self.encoder_output_dim,
                encoder_outputs_transposed=self.encoder_outputs_transposed,
                weighted_encoder_outputs=self.weighted_encoder_outputs,
                decoder_hidden_state_t=self.hidden_t_intermediate,
                decoder_hidden_state_dim=self.decoder_state_dim,
                scope=self.name,
                attention_weighted_encoder_context_t_prev=(
                    attention_weighted_encoder_context_t_prev
                ),
                encoder_lengths=self.encoder_lengths,
            )
        elif self.attention_type == AttentionType.Regular:
            (
                attention_weighted_encoder_context_t,
                self.attention_weights_3d,
                attention_blobs,
            ) = apply_regular_attention(
                model=model,
                encoder_output_dim=self.encoder_output_dim,
                encoder_outputs_transposed=self.encoder_outputs_transposed,
                weighted_encoder_outputs=self.weighted_encoder_outputs,
                decoder_hidden_state_t=self.hidden_t_intermediate,
                decoder_hidden_state_dim=self.decoder_state_dim,
                scope=self.name,
                encoder_lengths=self.encoder_lengths,
            )
        elif self.attention_type == AttentionType.Dot:
            (
                attention_weighted_encoder_context_t,
                self.attention_weights_3d,
                attention_blobs,
            ) = apply_dot_attention(
                model=model,
                encoder_output_dim=self.encoder_output_dim,
                encoder_outputs_transposed=self.encoder_outputs_transposed,
                decoder_hidden_state_t=self.hidden_t_intermediate,
                decoder_hidden_state_dim=self.decoder_state_dim,
                scope=self.name,
                encoder_lengths=self.encoder_lengths,
            )
        else:
            raise Exception('Attention type {} not implemented'.format(
                self.attention_type
            ))

        if self.attention_memory_optimization:
            self.recompute_blobs.extend(attention_blobs)

        output = list(decoder_states) + [attention_weighted_encoder_context_t]
        output[self.decoder_cell.get_output_state_index()] = model.Copy(
            output[self.decoder_cell.get_output_state_index()],
            self.scope('hidden_t_external'),
        )
        model.net.AddExternalOutputs(*output)

        return output
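
Of the three branches, the dot variant has the simplest energies: a plain dot product between the decoder hidden state and each encoder output, with padded positions masked out using encoder_lengths. A minimal numpy sketch (illustrative; not the apply_dot_attention implementation, and it assumes the decoder state and encoder output dimensions already match):

import numpy as np

# Dot-product attention with length masking (illustrative only).
def dot_attention_sketch(h_dec, enc_outputs, enc_length):
    # h_dec: [D], enc_outputs: [T_enc, D], enc_length: number of valid steps
    scores = enc_outputs @ h_dec                 # [T_enc]
    scores[enc_length:] = -np.inf                # mask padding positions
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()
    return weights @ enc_outputs, weights

context, weights = dot_attention_sketch(
    np.random.randn(4), np.random.randn(5, 4), enc_length=3)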
Example #7
    def _apply(
        self,
        model,
        input_t,
        seq_lengths,
        states,
        timestep,
        extra_inputs=None,
    ):
        decoder_prev_states = states[:-1]
        attention_weighted_encoder_context_t_prev = states[-1]

        assert extra_inputs is None

        decoder_states = self.decoder_cell._apply(
            model,
            input_t,
            seq_lengths,
            decoder_prev_states,
            timestep,
            extra_inputs=[(
                attention_weighted_encoder_context_t_prev,
                self.encoder_output_dim,
            )],
        )

        self.hidden_t_intermediate = self.decoder_cell._prepare_output(
            model,
            decoder_states,
        )

        if self.attention_type == AttentionType.Recurrent:
            (
                attention_weighted_encoder_context_t,
                self.attention_weights_3d,
                attention_blobs,
            ) = apply_recurrent_attention(
                model=model,
                encoder_output_dim=self.encoder_output_dim,
                encoder_outputs_transposed=self.encoder_outputs_transposed,
                weighted_encoder_outputs=self.weighted_encoder_outputs,
                decoder_hidden_state_t=self.hidden_t_intermediate,
                decoder_hidden_state_dim=self.decoder_state_dim,
                scope=self.name,
                attention_weighted_encoder_context_t_prev=(
                    attention_weighted_encoder_context_t_prev
                ),
                encoder_lengths=self.encoder_lengths,
            )
        else:
            (
                attention_weighted_encoder_context_t,
                self.attention_weights_3d,
                attention_blobs,
            ) = apply_regular_attention(
                model=model,
                encoder_output_dim=self.encoder_output_dim,
                encoder_outputs_transposed=self.encoder_outputs_transposed,
                weighted_encoder_outputs=self.weighted_encoder_outputs,
                decoder_hidden_state_t=self.hidden_t_intermediate,
                decoder_hidden_state_dim=self.decoder_state_dim,
                scope=self.name,
                encoder_lengths=self.encoder_lengths,
            )

        if self.attention_memory_optimization:
            self.recompute_blobs.extend(attention_blobs)

        output = list(decoder_states) + [attention_weighted_encoder_context_t]
        output[self.decoder_cell.get_output_state_index()] = model.Copy(
            output[self.decoder_cell.get_output_state_index()],
            self.scope('hidden_t_external'),
        )
        model.net.AddExternalOutputs(*output)

        return output