# Imports assumed by the snippets below (TF 1.x contrib API); adjust the
# paths if the surrounding project imports these differently.
import tensorflow as tf
from tensorflow.contrib import seq2seq
from tensorflow.contrib.rnn import MultiRNNCell
from tensorflow.contrib.seq2seq.python.ops import attention_wrapper
from tensorflow.python.layers.core import Dense
from tensorflow.python.ops import array_ops
from tensorflow.python.util import nest


def create_attention_mechanism(num_units, memory, source_sequence_length, dtype=None):
    """Create a (normalized Bahdanau) attention mechanism over the encoder memory."""
    attention_mechanism = attention_wrapper.BahdanauAttention(
        num_units,
        memory,
        memory_sequence_length=tf.to_int64(source_sequence_length),
        normalize=True,
        dtype=dtype)
    return attention_mechanism
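# Usage sketch (illustrative, not part of the original model): how the
# mechanism returned above is typically wrapped around a decoder cell with
# AttentionWrapper. The shapes, the GRUCell choice, and the helper name
# _attention_usage_example are assumptions made for this example only.
def _attention_usage_example(num_units=128, max_time=20):
    encoder_outputs = tf.placeholder(tf.float32, [None, max_time, num_units])
    source_sequence_length = tf.placeholder(tf.int32, [None])

    mechanism = create_attention_mechanism(
        num_units, encoder_outputs, source_sequence_length)

    decoder_cell = tf.contrib.rnn.GRUCell(num_units)
    attn_cell = attention_wrapper.AttentionWrapper(
        decoder_cell, mechanism, attention_layer_size=num_units)

    # zero_state returns an AttentionWrapperState compatible with the wrapped cell.
    batch_size = tf.shape(encoder_outputs)[0]
    initial_state = attn_cell.zero_state(batch_size, tf.float32)
    return attn_cell, initial_state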
def build_decoder_cell(self):
    encoder_outputs = self.encoder_outputs
    encoder_last_state = self.encoder_last_state
    encoder_inputs_length = self.encoder_inputs_length

    use_beamsearch_decode = (self.mode.lower() == 'decode' and self.beam_width > 1)

    # To use BeamSearchDecoder, encoder_outputs, encoder_last_state and
    # encoder_inputs_length need to be tiled so that:
    # [batch_size, .., ..] -> [batch_size x beam_width, .., ..]
    if use_beamsearch_decode:
        print("use beamsearch decoding..")
        encoder_outputs = seq2seq.tile_batch(
            self.encoder_outputs, multiplier=self.beam_width)
        encoder_last_state = nest.map_structure(
            lambda s: seq2seq.tile_batch(s, self.beam_width),
            self.encoder_last_state)
        encoder_inputs_length = seq2seq.tile_batch(
            self.encoder_inputs_length, multiplier=self.beam_width)

    # Building attention mechanism: default is Bahdanau
    # 'Bahdanau' style attention: https://arxiv.org/abs/1409.0473
    self.attention_mechanism = attention_wrapper.BahdanauAttention(
        num_units=self.hidden_units,
        memory=encoder_outputs,
        memory_sequence_length=encoder_inputs_length)

    # 'Luong' style attention: https://arxiv.org/abs/1508.04025
    if self.attention_type.lower() == 'luong':
        self.attention_mechanism = attention_wrapper.LuongAttention(
            num_units=self.hidden_units,
            memory=encoder_outputs,
            memory_sequence_length=encoder_inputs_length)

    # Building decoder_cell
    self.decoder_cell_list = [
        self.build_single_cell() for i in range(self.depth)]

    def attn_decoder_input_fn(inputs, attention):
        if not self.attn_input_feeding:
            return inputs

        # Essential when use_residual=True
        _input_layer = Dense(self.hidden_units, dtype=self.dtype,
                             name='attn_input_feeding')
        return _input_layer(array_ops.concat([inputs, attention], -1))

    # AttentionWrapper wraps an RNNCell with the attention_mechanism
    # Note: we apply the attention mechanism only on the top decoder layer
    self.decoder_cell_list[-1] = attention_wrapper.AttentionWrapper(
        cell=self.decoder_cell_list[-1],
        attention_mechanism=self.attention_mechanism,
        attention_layer_size=self.hidden_units,
        cell_input_fn=attn_decoder_input_fn,
        initial_cell_state=encoder_last_state[-1],
        alignment_history=False,
        name='Attention_Wrapper')

    # To be compatible with AttentionWrapper, the encoder last state of the
    # top layer has to be converted into the AttentionWrapperState form.
    # We can easily do this by calling AttentionWrapper.zero_state.
    # If beamsearch decoding is used, the batch_size argument passed to
    # .zero_state must be beam_width times the original batch_size.
    batch_size = self.batch_size * self.beam_width \
        if use_beamsearch_decode else self.batch_size
    initial_state = [state for state in encoder_last_state]
    initial_state[-1] = self.decoder_cell_list[-1].zero_state(
        batch_size=batch_size, dtype=self.dtype)
    decoder_initial_state = tuple(initial_state)

    return MultiRNNCell(self.decoder_cell_list), decoder_initial_state
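# Sketch of how the (cell, initial_state) pair returned by build_decoder_cell
# is typically consumed (an illustrative companion method, not from the
# original model). Attributes such as self.embedding, self.start_token,
# self.end_token, self.num_decoder_symbols and self.max_decode_step are
# hypothetical names standing in for the model's own fields.
def build_inference_decoder(self):
    decoder_cell, decoder_initial_state = self.build_decoder_cell()
    output_layer = Dense(self.num_decoder_symbols, name='output_projection')
    start_tokens = tf.fill([self.batch_size], self.start_token)

    if self.mode.lower() == 'decode' and self.beam_width > 1:
        # BeamSearchDecoder expects the beam-tiled initial state built above.
        decoder = seq2seq.BeamSearchDecoder(
            cell=decoder_cell,
            embedding=self.embedding,
            start_tokens=start_tokens,
            end_token=self.end_token,
            initial_state=decoder_initial_state,
            beam_width=self.beam_width,
            output_layer=output_layer)
    else:
        helper = seq2seq.GreedyEmbeddingHelper(
            embedding=self.embedding,
            start_tokens=start_tokens,
            end_token=self.end_token)
        decoder = seq2seq.BasicDecoder(
            cell=decoder_cell,
            helper=helper,
            initial_state=decoder_initial_state,
            output_layer=output_layer)

    outputs, final_state, seq_lengths = seq2seq.dynamic_decode(
        decoder, maximum_iterations=self.max_decode_step)
    return outputs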