def LSTMWithAttention(
    model,
    decoder_inputs,
    decoder_input_lengths,
    initial_decoder_hidden_state,
    initial_decoder_cell_state,
    initial_attention_weighted_encoder_context,
    encoder_output_dim,
    encoder_outputs,
    decoder_input_dim,
    decoder_state_dim,
    scope,
    attention_type=AttentionType.Regular,
    outputs_with_grads=(0, 4),
    weighted_encoder_outputs=None,
    lstm_memory_optimization=False,
    attention_memory_optimization=False,
    forget_bias=0.0,
):
    '''
    Adds an LSTM with an attention mechanism to a model.

    The implementation is based on https://arxiv.org/abs/1409.0473, with
    a small difference in the order in which we compute the new attention
    context and the new hidden state, similar to
    https://arxiv.org/abs/1508.04025.

    The model uses encoder-decoder naming conventions, where the decoder is
    the sequence the op iterates over, while computing the attention context
    over the encoder.

    model: CNNModelHelper object that new operators are added to

    decoder_inputs: the input sequence in a format T x N x D,
    where T is sequence size, N - batch size and D - input dimension

    decoder_input_lengths: blob containing sequence lengths
    which are passed to the LSTMUnit operator

    initial_decoder_hidden_state: initial hidden state of LSTM

    initial_decoder_cell_state: initial cell state of LSTM

    initial_attention_weighted_encoder_context: initial attention context

    encoder_output_dim: dimension of encoder outputs

    encoder_outputs: the sequence over which we compute the attention context
    at every iteration

    decoder_input_dim: input dimension (last dimension of decoder_inputs)

    decoder_state_dim: size of hidden states of LSTM

    attention_type: One of: AttentionType.Regular, AttentionType.Recurrent.
    Determines which type of attention mechanism to use.

    outputs_with_grads: position indices of output blobs which will receive
    an external error gradient during backpropagation

    weighted_encoder_outputs: encoder outputs to be used to compute attention
    weights. In the basic case it's just a linear transformation of encoder
    outputs (that's the default, when weighted_encoder_outputs is None).
    However, it can be something more complicated - like a separate encoder
    network (for example, in the case of a convolutional encoder).

    lstm_memory_optimization: recompute LSTM activations on the backward pass,
    so we don't need to store their values from the forward pass

    attention_memory_optimization: recompute attention for the backward pass
    '''
    def s(name):
        # We have to manually scope due to our internal/external blob
        # relationships.
        return "{}/{}".format(str(scope), str(name))

    decoder_inputs = model.FC(
        decoder_inputs,
        s('i2h'),
        dim_in=decoder_input_dim,
        dim_out=4 * decoder_state_dim,
        axis=2,
    )
    # [batch_size, encoder_output_dim, encoder_length]
    encoder_outputs_transposed = model.Transpose(
        encoder_outputs,
        s('encoder_outputs_transposed'),
        axes=[1, 2, 0],
    )
    if weighted_encoder_outputs is None:
        weighted_encoder_outputs = model.FC(
            encoder_outputs,
            s('weighted_encoder_outputs'),
            dim_in=encoder_output_dim,
            dim_out=encoder_output_dim,
            axis=2,
        )
    step_model = CNNModelHelper(
        name='lstm_with_attention_cell',
        param_model=model,
    )
    (
        input_t,
        timestep,
        cell_t_prev,
        hidden_t_prev,
        attention_weighted_encoder_context_t_prev,
    ) = (
        step_model.net.AddScopedExternalInputs(
            'input_t',
            'timestep',
            'cell_t_prev',
            'hidden_t_prev',
            'attention_weighted_encoder_context_t_prev',
        )
    )
    step_model.net.AddExternalInputs(
        encoder_outputs_transposed,
        weighted_encoder_outputs,
    )
    gates_concatenated_input_t, _ = step_model.net.Concat(
        [hidden_t_prev, attention_weighted_encoder_context_t_prev],
        [
            s('gates_concatenated_input_t'),
            s('_gates_concatenated_input_t_concat_dims'),
        ],
        axis=2,
    )
    gates_t = step_model.FC(
        gates_concatenated_input_t,
        s('gates_t'),
        dim_in=decoder_state_dim + encoder_output_dim,
        dim_out=4 * decoder_state_dim,
        axis=2,
    )
    step_model.net.Sum([gates_t, input_t], gates_t)

    hidden_t_intermediate, cell_t = step_model.net.LSTMUnit(
        [hidden_t_prev, cell_t_prev, gates_t, decoder_input_lengths, timestep],
        ['hidden_t_intermediate', s('cell_t')],
        forget_bias=forget_bias,
    )
    if attention_type == AttentionType.Recurrent:
        attention_weighted_encoder_context_t, _, attention_blobs = \
            apply_recurrent_attention(
                model=step_model,
                encoder_output_dim=encoder_output_dim,
                encoder_outputs_transposed=encoder_outputs_transposed,
                weighted_encoder_outputs=weighted_encoder_outputs,
                decoder_hidden_state_t=hidden_t_intermediate,
                decoder_hidden_state_dim=decoder_state_dim,
                scope=scope,
                attention_weighted_encoder_context_t_prev=(
                    attention_weighted_encoder_context_t_prev
                ),
            )
    else:
        attention_weighted_encoder_context_t, _, attention_blobs = \
            apply_regular_attention(
                model=step_model,
                encoder_output_dim=encoder_output_dim,
                encoder_outputs_transposed=encoder_outputs_transposed,
                weighted_encoder_outputs=weighted_encoder_outputs,
                decoder_hidden_state_t=hidden_t_intermediate,
                decoder_hidden_state_dim=decoder_state_dim,
                scope=scope,
            )
    hidden_t = step_model.Copy(hidden_t_intermediate, s('hidden_t'))
    step_model.net.AddExternalOutputs(
        cell_t,
        hidden_t,
        attention_weighted_encoder_context_t,
    )

    recompute_blobs = []
    if attention_memory_optimization:
        recompute_blobs.extend(attention_blobs)
    if lstm_memory_optimization:
        recompute_blobs.extend([gates_t])

    return recurrent_net(
        net=model.net,
        cell_net=step_model.net,
        inputs=[
            (input_t, decoder_inputs),
        ],
        initial_cell_inputs=[
            (hidden_t_prev, initial_decoder_hidden_state),
            (cell_t_prev, initial_decoder_cell_state),
            (
                attention_weighted_encoder_context_t_prev,
                initial_attention_weighted_encoder_context,
            ),
        ],
        links={
            hidden_t_prev: hidden_t,
            cell_t_prev: cell_t,
            attention_weighted_encoder_context_t_prev: (
                attention_weighted_encoder_context_t
            ),
        },
        timestep=timestep,
        scope=scope,
        outputs_with_grads=outputs_with_grads,
        recompute_blobs_on_backward=recompute_blobs,
    )


def _apply(
    self,
    model,
    input_t,
    seq_lengths,
    states,
    timestep,
):
    decoder_prev_states = states[:-1]
    attention_weighted_encoder_context_t_prev = states[-1]

    decoder_states = self.decoder_cell._apply(
        model,
        input_t,
        seq_lengths,
        decoder_prev_states,
        timestep,
        extra_inputs=[(
            attention_weighted_encoder_context_t_prev,
            self.encoder_output_dim,
        )],
    )

    hidden_t_intermediate = \
        decoder_states[self.decoder_cell.output_state_index()]

    if self.attention_type == AttentionType.Recurrent:
        (
            attention_weighted_encoder_context_t,
            self.attention_weights_3d,
            attention_blobs,
        ) = apply_recurrent_attention(
            model=model,
            encoder_output_dim=self.encoder_output_dim,
            encoder_outputs_transposed=self.encoder_outputs_transposed,
            weighted_encoder_outputs=self.weighted_encoder_outputs,
            decoder_hidden_state_t=hidden_t_intermediate,
            decoder_hidden_state_dim=self.decoder_state_dim,
            scope=self.name,
            attention_weighted_encoder_context_t_prev=(
                attention_weighted_encoder_context_t_prev
            ),
        )
    else:
        (
            attention_weighted_encoder_context_t,
            self.attention_weights_3d,
            attention_blobs,
        ) = apply_regular_attention(
            model=model,
            encoder_output_dim=self.encoder_output_dim,
            encoder_outputs_transposed=self.encoder_outputs_transposed,
            weighted_encoder_outputs=self.weighted_encoder_outputs,
            decoder_hidden_state_t=hidden_t_intermediate,
            decoder_hidden_state_dim=self.decoder_state_dim,
            scope=self.name,
        )
    if self.attention_memory_optimization:
        self.recompute_blobs.extend(attention_blobs)

    hidden_t = model.Copy(
        hidden_t_intermediate,
        self.scope('hidden_t_external'),
    )
    output = list(decoder_states) + [attention_weighted_encoder_context_t]
    output[self.decoder_cell.output_state_index()] = hidden_t
    model.net.AddExternalOutputs(*output)

    return output


def _apply(
    self,
    model,
    input_t,
    seq_lengths,
    states,
    timestep,
):
    decoder_prev_states = states[:-1]
    attention_weighted_encoder_context_t_prev = states[-1]

    decoder_states = self.decoder_cell._apply(
        model,
        input_t,
        seq_lengths,
        decoder_prev_states,
        timestep,
        extra_inputs=[(
            attention_weighted_encoder_context_t_prev,
            self.encoder_output_dim,
        )],
    )

    # TODO: we should use the prepare_output method here, but because of a
    # recurrent_net edge case we have to know which state is being used to
    # compute attention. So instead of manipulating the output of the cell,
    # we have to work with the output state directly. In other words, if the
    # output of decoder_cell is not equal to one of decoder_cell's states
    # (the one at get_output_state_index()), then this logic is broken.
    # Right now, that can happen if there is dropout, so we explicitly check
    # that dropout has been disabled.
    assert self.decoder_cell.dropout_ratio is None
    hidden_t_intermediate = \
        decoder_states[self.decoder_cell.get_output_state_index()]

    if self.attention_type == AttentionType.Recurrent:
        (
            attention_weighted_encoder_context_t,
            self.attention_weights_3d,
            attention_blobs,
        ) = apply_recurrent_attention(
            model=model,
            encoder_output_dim=self.encoder_output_dim,
            encoder_outputs_transposed=self.encoder_outputs_transposed,
            weighted_encoder_outputs=self.weighted_encoder_outputs,
            decoder_hidden_state_t=hidden_t_intermediate,
            decoder_hidden_state_dim=self.decoder_state_dim,
            scope=self.name,
            attention_weighted_encoder_context_t_prev=(
                attention_weighted_encoder_context_t_prev
            ),
        )
    else:
        (
            attention_weighted_encoder_context_t,
            self.attention_weights_3d,
            attention_blobs,
        ) = apply_regular_attention(
            model=model,
            encoder_output_dim=self.encoder_output_dim,
            encoder_outputs_transposed=self.encoder_outputs_transposed,
            weighted_encoder_outputs=self.weighted_encoder_outputs,
            decoder_hidden_state_t=hidden_t_intermediate,
            decoder_hidden_state_dim=self.decoder_state_dim,
            scope=self.name,
        )
    if self.attention_memory_optimization:
        self.recompute_blobs.extend(attention_blobs)

    hidden_t = model.Copy(
        hidden_t_intermediate,
        self.scope('hidden_t_external'),
    )
    output = list(decoder_states) + [attention_weighted_encoder_context_t]
    output[self.decoder_cell.get_output_state_index()] = hidden_t
    model.net.AddExternalOutputs(*output)

    return output


def _apply(
    self,
    model,
    input_t,
    seq_lengths,
    states,
    timestep,
):
    (
        hidden_t_prev,
        cell_t_prev,
        attention_weighted_encoder_context_t_prev,
    ) = states

    gates_concatenated_input_t, _ = model.net.Concat(
        [hidden_t_prev, attention_weighted_encoder_context_t_prev],
        [
            self.scope('gates_concatenated_input_t'),
            self.scope('_gates_concatenated_input_t_concat_dims'),
        ],
        axis=2,
    )
    # hU^T
    # Shape: [1, batch_size, 4 * hidden_size]
    prev_t = model.FC(
        gates_concatenated_input_t,
        self.scope('prev_t'),
        dim_in=self.decoder_state_dim + self.encoder_output_dim,
        dim_out=4 * self.decoder_state_dim,
        axis=2,
    )
    # defining MI parameters
    alpha = model.param_init_net.ConstantFill(
        [],
        [self.scope('alpha')],
        shape=[4 * self.decoder_state_dim],
        value=1.0,
    )
    beta1 = model.param_init_net.ConstantFill(
        [],
        [self.scope('beta1')],
        shape=[4 * self.decoder_state_dim],
        value=1.0,
    )
    beta2 = model.param_init_net.ConstantFill(
        [],
        [self.scope('beta2')],
        shape=[4 * self.decoder_state_dim],
        value=1.0,
    )
    b = model.param_init_net.ConstantFill(
        [],
        [self.scope('b')],
        shape=[4 * self.decoder_state_dim],
        value=0.0,
    )
    model.params.extend([alpha, beta1, beta2, b])
    # alpha * (xW^T * hU^T)
    # Shape: [1, batch_size, 4 * hidden_size]
    alpha_tdash = model.net.Mul(
        [prev_t, input_t],
        self.scope('alpha_tdash'),
    )
    # Shape: [batch_size, 4 * hidden_size]
    alpha_tdash_rs, _ = model.net.Reshape(
        alpha_tdash,
        [
            self.scope('alpha_tdash_rs'),
            self.scope('alpha_tdash_old_shape'),
        ],
        shape=[-1, 4 * self.decoder_state_dim],
    )
    alpha_t = model.net.Mul(
        [alpha_tdash_rs, alpha],
        self.scope('alpha_t'),
        broadcast=1,
        use_grad_hack=1,
    )
    # beta1 * hU^T
    # Shape: [batch_size, 4 * hidden_size]
    prev_t_rs, _ = model.net.Reshape(
        prev_t,
        [self.scope('prev_t_rs'), self.scope('prev_t_old_shape')],
        shape=[-1, 4 * self.decoder_state_dim],
    )
    beta1_t = model.net.Mul(
        [prev_t_rs, beta1],
        self.scope('beta1_t'),
        broadcast=1,
        use_grad_hack=1,
    )
    # beta2 * xW^T
    # Shape: [batch_size, 4 * hidden_size]
    input_t_rs, _ = model.net.Reshape(
        input_t,
        [self.scope('input_t_rs'), self.scope('input_t_old_shape')],
        shape=[-1, 4 * self.decoder_state_dim],
    )
    beta2_t = model.net.Mul(
        [input_t_rs, beta2],
        self.scope('beta2_t'),
        broadcast=1,
        use_grad_hack=1,
    )
    # Add 'em all up
    gates_tdash = model.net.Sum(
        [alpha_t, beta1_t, beta2_t],
        self.scope('gates_tdash'),
    )
    gates_t = model.net.Add(
        [gates_tdash, b],
        self.scope('gates_t'),
        broadcast=1,
        use_grad_hack=1,
    )
    # Shape: [1, batch_size, 4 * hidden_size]
    gates_t_rs, _ = model.net.Reshape(
        gates_t,
        [self.scope('gates_t_rs'), self.scope('gates_t_old_shape')],
        shape=[1, -1, 4 * self.decoder_state_dim],
    )

    hidden_t_intermediate, cell_t = model.net.LSTMUnit(
        [hidden_t_prev, cell_t_prev, gates_t_rs, seq_lengths, timestep],
        [self.scope('hidden_t_intermediate'), self.scope('cell_t')],
    )
    if self.attention_type == AttentionType.Recurrent:
        (
            attention_weighted_encoder_context_t,
            self.attention_weights_3d,
            self.recompute_blobs,
        ) = apply_recurrent_attention(
            model=model,
            encoder_output_dim=self.encoder_output_dim,
            encoder_outputs_transposed=self.encoder_outputs_transposed,
            weighted_encoder_outputs=self.weighted_encoder_outputs,
            decoder_hidden_state_t=hidden_t_intermediate,
            decoder_hidden_state_dim=self.decoder_state_dim,
            scope=self.name,
            attention_weighted_encoder_context_t_prev=(
                attention_weighted_encoder_context_t_prev
            ),
        )
    else:
        (
            attention_weighted_encoder_context_t,
            self.attention_weights_3d,
            self.recompute_blobs,
        ) = apply_regular_attention(
            model=model,
            encoder_output_dim=self.encoder_output_dim,
            encoder_outputs_transposed=self.encoder_outputs_transposed,
            weighted_encoder_outputs=self.weighted_encoder_outputs,
            decoder_hidden_state_t=hidden_t_intermediate,
            decoder_hidden_state_dim=self.decoder_state_dim,
            scope=self.name,
        )
    hidden_t = model.Copy(hidden_t_intermediate, self.scope('hidden_t'))
    model.net.AddExternalOutputs(
        cell_t,
        hidden_t,
        attention_weighted_encoder_context_t,
    )
    return hidden_t, cell_t, attention_weighted_encoder_context_t
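

# The gate pre-activation built above follows the multiplicative-integration
# form: gates = alpha * (xW^T * hU^T) + beta1 * hU^T + beta2 * xW^T + b,
# where xW^T is the projected decoder input (input_t) and hU^T is the
# projection of the concatenated previous hidden state and attention context
# (prev_t). The plain-array sketch below restates that formula outside the
# Caffe2 graph; the helper name and shapes are illustrative assumptions, not
# part of the original module.
def _mi_gates_reference(xW, hU, alpha, beta1, beta2, b):
    # xW, hU: arrays of shape [batch_size, 4 * hidden_size]
    # alpha, beta1, beta2, b: arrays of shape [4 * hidden_size]
    return alpha * (xW * hU) + beta1 * hU + beta2 * xW + b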


def _apply(
    self,
    model,
    input_t,
    seq_lengths,
    states,
    timestep,
):
    (
        hidden_t_prev,
        cell_t_prev,
        attention_weighted_encoder_context_t_prev,
    ) = states

    gates_concatenated_input_t, _ = model.net.Concat(
        [hidden_t_prev, attention_weighted_encoder_context_t_prev],
        [
            self.scope('gates_concatenated_input_t'),
            self.scope('_gates_concatenated_input_t_concat_dims'),
        ],
        axis=2,
    )
    gates_t = model.FC(
        gates_concatenated_input_t,
        self.scope('gates_t'),
        dim_in=self.decoder_state_dim + self.encoder_output_dim,
        dim_out=4 * self.decoder_state_dim,
        axis=2,
    )
    model.net.Sum([gates_t, input_t], gates_t)

    hidden_t_intermediate, cell_t = model.net.LSTMUnit(
        [
            hidden_t_prev,
            cell_t_prev,
            gates_t,
            seq_lengths,
            timestep,
        ],
        ['hidden_t_intermediate', self.scope('cell_t')],
    )
    if self.attention_type == AttentionType.Recurrent:
        (
            attention_weighted_encoder_context_t,
            self.attention_weights_3d,
            attention_blobs,
        ) = apply_recurrent_attention(
            model=model,
            encoder_output_dim=self.encoder_output_dim,
            encoder_outputs_transposed=self.encoder_outputs_transposed,
            weighted_encoder_outputs=self.weighted_encoder_outputs,
            decoder_hidden_state_t=hidden_t_intermediate,
            decoder_hidden_state_dim=self.decoder_state_dim,
            scope=self.name,
            attention_weighted_encoder_context_t_prev=(
                attention_weighted_encoder_context_t_prev
            ),
        )
    else:
        (
            attention_weighted_encoder_context_t,
            self.attention_weights_3d,
            attention_blobs,
        ) = apply_regular_attention(
            model=model,
            encoder_output_dim=self.encoder_output_dim,
            encoder_outputs_transposed=self.encoder_outputs_transposed,
            weighted_encoder_outputs=self.weighted_encoder_outputs,
            decoder_hidden_state_t=hidden_t_intermediate,
            decoder_hidden_state_dim=self.decoder_state_dim,
            scope=self.name,
        )
    hidden_t = model.Copy(hidden_t_intermediate, self.scope('hidden_t'))
    model.net.AddExternalOutputs(
        cell_t,
        hidden_t,
        attention_weighted_encoder_context_t,
    )
    if self.attention_memory_optimization:
        self.recompute_blobs.extend(attention_blobs)
    if self.lstm_memory_optimization:
        self.recompute_blobs.append(gates_t)

    return hidden_t, cell_t, attention_weighted_encoder_context_t


def _apply(
    self,
    model,
    input_t,
    seq_lengths,
    states,
    timestep,
    extra_inputs=None,
):
    decoder_prev_states = states[:-1]
    attention_weighted_encoder_context_t_prev = states[-1]

    assert extra_inputs is None

    decoder_states = self.decoder_cell._apply(
        model,
        input_t,
        seq_lengths,
        decoder_prev_states,
        timestep,
        extra_inputs=[(
            attention_weighted_encoder_context_t_prev,
            self.encoder_output_dim,
        )],
    )

    self.hidden_t_intermediate = self.decoder_cell._prepare_output(
        model,
        decoder_states,
    )

    if self.attention_type == AttentionType.Recurrent:
        (
            attention_weighted_encoder_context_t,
            self.attention_weights_3d,
            attention_blobs,
        ) = apply_recurrent_attention(
            model=model,
            encoder_output_dim=self.encoder_output_dim,
            encoder_outputs_transposed=self.encoder_outputs_transposed,
            weighted_encoder_outputs=self.weighted_encoder_outputs,
            decoder_hidden_state_t=self.hidden_t_intermediate,
            decoder_hidden_state_dim=self.decoder_state_dim,
            scope=self.name,
            attention_weighted_encoder_context_t_prev=(
                attention_weighted_encoder_context_t_prev
            ),
            encoder_lengths=self.encoder_lengths,
        )
    elif self.attention_type == AttentionType.Regular:
        (
            attention_weighted_encoder_context_t,
            self.attention_weights_3d,
            attention_blobs,
        ) = apply_regular_attention(
            model=model,
            encoder_output_dim=self.encoder_output_dim,
            encoder_outputs_transposed=self.encoder_outputs_transposed,
            weighted_encoder_outputs=self.weighted_encoder_outputs,
            decoder_hidden_state_t=self.hidden_t_intermediate,
            decoder_hidden_state_dim=self.decoder_state_dim,
            scope=self.name,
            encoder_lengths=self.encoder_lengths,
        )
    elif self.attention_type == AttentionType.Dot:
        (
            attention_weighted_encoder_context_t,
            self.attention_weights_3d,
            attention_blobs,
        ) = apply_dot_attention(
            model=model,
            encoder_output_dim=self.encoder_output_dim,
            encoder_outputs_transposed=self.encoder_outputs_transposed,
            decoder_hidden_state_t=self.hidden_t_intermediate,
            decoder_hidden_state_dim=self.decoder_state_dim,
            scope=self.name,
            encoder_lengths=self.encoder_lengths,
        )
    else:
        raise Exception('Attention type {} not implemented'.format(
            self.attention_type
        ))

    if self.attention_memory_optimization:
        self.recompute_blobs.extend(attention_blobs)

    output = list(decoder_states) + [attention_weighted_encoder_context_t]

    output[self.decoder_cell.get_output_state_index()] = model.Copy(
        output[self.decoder_cell.get_output_state_index()],
        self.scope('hidden_t_external'),
    )
    model.net.AddExternalOutputs(*output)

    return output
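

# The AttentionType.Dot branch above delegates to apply_dot_attention. As a
# point of reference, dot-product attention in the sense of
# https://arxiv.org/abs/1508.04025 scores each encoder position by its dot
# product with the decoder hidden state, softmaxes the scores, and averages
# the encoder outputs with those weights. The numpy sketch below illustrates
# those generic semantics; it is an assumption for illustration, not the
# actual apply_dot_attention implementation, and the helper name is
# hypothetical.
def _dot_attention_reference(decoder_hidden, encoder_outputs):
    import numpy as np
    # decoder_hidden: [batch_size, dim]
    # encoder_outputs: [encoder_length, batch_size, dim]
    # scores[t, n] = encoder_outputs[t, n] . decoder_hidden[n]
    scores = np.einsum('tnd,nd->tn', encoder_outputs, decoder_hidden)
    # softmax over encoder positions (axis 0)
    weights = np.exp(scores - scores.max(axis=0, keepdims=True))
    weights /= weights.sum(axis=0, keepdims=True)
    # context[n] = sum_t weights[t, n] * encoder_outputs[t, n]
    context = np.einsum('tn,tnd->nd', weights, encoder_outputs)
    return context, weights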


def _apply(
    self,
    model,
    input_t,
    seq_lengths,
    states,
    timestep,
    extra_inputs=None,
):
    decoder_prev_states = states[:-1]
    attention_weighted_encoder_context_t_prev = states[-1]

    assert extra_inputs is None

    decoder_states = self.decoder_cell._apply(
        model,
        input_t,
        seq_lengths,
        decoder_prev_states,
        timestep,
        extra_inputs=[(
            attention_weighted_encoder_context_t_prev,
            self.encoder_output_dim,
        )],
    )

    self.hidden_t_intermediate = self.decoder_cell._prepare_output(
        model,
        decoder_states,
    )

    if self.attention_type == AttentionType.Recurrent:
        (
            attention_weighted_encoder_context_t,
            self.attention_weights_3d,
            attention_blobs,
        ) = apply_recurrent_attention(
            model=model,
            encoder_output_dim=self.encoder_output_dim,
            encoder_outputs_transposed=self.encoder_outputs_transposed,
            weighted_encoder_outputs=self.weighted_encoder_outputs,
            decoder_hidden_state_t=self.hidden_t_intermediate,
            decoder_hidden_state_dim=self.decoder_state_dim,
            scope=self.name,
            attention_weighted_encoder_context_t_prev=(
                attention_weighted_encoder_context_t_prev
            ),
            encoder_lengths=self.encoder_lengths,
        )
    else:
        (
            attention_weighted_encoder_context_t,
            self.attention_weights_3d,
            attention_blobs,
        ) = apply_regular_attention(
            model=model,
            encoder_output_dim=self.encoder_output_dim,
            encoder_outputs_transposed=self.encoder_outputs_transposed,
            weighted_encoder_outputs=self.weighted_encoder_outputs,
            decoder_hidden_state_t=self.hidden_t_intermediate,
            decoder_hidden_state_dim=self.decoder_state_dim,
            scope=self.name,
            encoder_lengths=self.encoder_lengths,
        )

    if self.attention_memory_optimization:
        self.recompute_blobs.extend(attention_blobs)

    output = list(decoder_states) + [attention_weighted_encoder_context_t]

    output[self.decoder_cell.get_output_state_index()] = model.Copy(
        output[self.decoder_cell.get_output_state_index()],
        self.scope('hidden_t_external'),
    )
    model.net.AddExternalOutputs(*output)

    return output