def _build_attention(self, encoder_outputs, encoder_sequence_length):
    """
    Builds the attention part of the graph. Currently supports "bahdanau" and "luong".
    :param encoder_outputs: encoder memory to attend over, [batch_size, time, units]
    :param encoder_sequence_length: lengths of the encoder sequences, [batch_size]
    :return: an attention mechanism instance
    """
    with tf.variable_scope("Attention"):
        attention_depth = self.model_params['attention_layer_size']
        if self.model_params['attention_type'] == 'bahdanau':
            bah_normalize = self.model_params.get('bahdanau_normalize', False)
            attention_mechanism = attention_wrapper.BahdanauAttention(
                num_units=attention_depth,
                memory=encoder_outputs,
                normalize=bah_normalize,
                memory_sequence_length=encoder_sequence_length,
                probability_fn=tf.nn.softmax)
        elif self.model_params['attention_type'] == 'luong':
            luong_scale = self.model_params.get('luong_scale', False)
            attention_mechanism = attention_wrapper.LuongAttention(
                num_units=attention_depth,
                memory=encoder_outputs,
                scale=luong_scale,
                memory_sequence_length=encoder_sequence_length,
                probability_fn=tf.nn.softmax)
        else:
            raise ValueError('Unknown Attention Type')
        return attention_mechanism
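# A minimal, hypothetical sketch of how the mechanism returned by
# _build_attention is typically consumed: wrapping the decoder cell in an
# AttentionWrapper. `decoder_cell` and this helper are assumptions for
# illustration, not part of the original model code.
def _wrap_decoder_cell(self, decoder_cell, attention_mechanism):
    # attention_layer_size projects the attention context; reusing the
    # configured size keeps the decoder output dimensionality unchanged.
    return attention_wrapper.AttentionWrapper(
        cell=decoder_cell,
        attention_mechanism=attention_mechanism,
        attention_layer_size=self.model_params['attention_layer_size'],
        alignment_history=False)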
def build_decoder_cell(self):
    encoder_outputs = self.encoder_outputs
    encoder_last_state = self.encoder_last_state
    encoder_inputs_length = self.encoder_inputs_length

    if self.use_beamsearch_decode:
        print("use beamsearch decoding..")
        encoder_outputs = seq2seq.tile_batch(
            self.encoder_outputs, multiplier=self.beam_width)
        encoder_last_state = nest.map_structure(
            lambda s: seq2seq.tile_batch(s, self.beam_width),
            self.encoder_last_state)
        encoder_inputs_length = seq2seq.tile_batch(
            self.encoder_inputs_length, multiplier=self.beam_width)

    # Building attention mechanism: default Bahdanau
    # 'Bahdanau' style attention: https://arxiv.org/abs/1409.0473
    self.attention_mechanism = attention_wrapper.BahdanauAttention(
        num_units=self.hidden_units,
        memory=encoder_outputs,
        memory_sequence_length=encoder_inputs_length)
    # 'Luong' style attention: https://arxiv.org/abs/1508.04025
    if self.attention_type.lower() == 'luong':
        self.attention_mechanism = attention_wrapper.LuongAttention(
            num_units=self.hidden_units,
            memory=encoder_outputs,
            memory_sequence_length=encoder_inputs_length)

    # Building decoder_cell
    self.decoder_cell_list = [
        self.build_single_cell() for _ in range(self.depth)]

    def attn_decoder_input_fn(inputs, attention):
        if not self.attn_input_feeding:
            return inputs
        # Essential when use_residual=True
        _input_layer = Dense(self.hidden_units, dtype=self.dtype,
                             name='attn_input_feeding')
        return _input_layer(array_ops.concat([inputs, attention], -1))

    # AttentionWrapper wraps the RNNCell with the attention_mechanism.
    # Note: the attention mechanism is applied only on the top decoder layer.
    self.decoder_cell_list[-1] = attention_wrapper.AttentionWrapper(
        cell=self.decoder_cell_list[-1],
        attention_mechanism=self.attention_mechanism,
        attention_layer_size=self.hidden_units,
        cell_input_fn=attn_decoder_input_fn,
        initial_cell_state=encoder_last_state[-1],
        alignment_history=False,
        name='Attention_Wrapper')

    batch_size = self.batch_size if not self.use_beamsearch_decode \
        else self.batch_size * self.beam_width
    # The top layer's initial state must be an AttentionWrapperState;
    # zero_state builds it (seeded with encoder_last_state[-1] above).
    initial_state = [state for state in encoder_last_state]
    initial_state[-1] = self.decoder_cell_list[-1].zero_state(
        batch_size=batch_size, dtype=self.dtype)
    decoder_initial_state = tuple(initial_state)

    return MultiRNNCell(self.decoder_cell_list), decoder_initial_state
def testCustomizedAttention(self):
    batch_size = 2
    max_time = 3
    num_units = 2
    memory = constant_op.constant([[[1., 1.], [2., 2.], [3., 3.]],
                                   [[4., 4.], [5., 5.], [6., 6.]]])
    memory_sequence_length = constant_op.constant([3, 2])
    attention_mechanism = wrapper.BahdanauAttention(num_units, memory,
                                                    memory_sequence_length)

    # Sets all returned values to be all ones.
    def _customized_attention(unused_attention_mechanism, unused_cell_output,
                              unused_attention_state, unused_attention_layer):
        """Customized attention.

        Returns:
          attention: `Tensor` of shape [batch_size, num_units], attention output.
          alignments: `Tensor` of shape [batch_size, max_time], sigma value for
            each input memory (prob. function of input keys).
          next_attention_state: A `Tensor` representing the next state for the
            attention.
        """
        attention = array_ops.ones([batch_size, num_units])
        alignments = array_ops.ones([batch_size, max_time])
        next_attention_state = alignments
        return attention, alignments, next_attention_state

    attention_cell = wrapper.AttentionWrapper(
        rnn_cell.LSTMCell(2),
        attention_mechanism,
        attention_layer_size=None,  # don't use attention layer.
        output_attention=False,
        alignment_history=(),
        attention_fn=_customized_attention,
        name='attention')
    self.assertEqual(num_units, attention_cell.output_size)

    initial_state = attention_cell.zero_state(
        batch_size=2, dtype=dtypes.float32)
    source_input_emb = array_ops.ones([2, 3, 2])
    source_input_length = constant_op.constant([3, 2])

    # 'state' is a tuple of
    # (cell_state, h, attention, alignments, alignment_history, attention_state)
    output, state = rnn.dynamic_rnn(
        attention_cell,
        inputs=source_input_emb,
        sequence_length=source_input_length,
        initial_state=initial_state,
        dtype=dtypes.float32)

    with self.session() as sess:
        sess.run(variables.global_variables_initializer())
        output_value, state_value = sess.run([output, state], feed_dict={})
        self.assertAllEqual(np.array([2, 3, 2]), output_value.shape)
        self.assertAllClose(np.array([[1., 1.], [1., 1.]]),
                            state_value.attention)
        self.assertAllClose(
            np.array([[1., 1., 1.], [1., 1., 1.]]), state_value.alignments)
        self.assertAllClose(
            np.array([[1., 1., 1.], [1., 1., 1.]]), state_value.attention_state)
def testBahdanauNormalizedDType(self):
    for dtype in [np.float16, np.float32, np.float64]:
        num_units = 128
        encoder_outputs = array_ops.placeholder(dtype, shape=[64, None, 256])
        encoder_sequence_length = array_ops.placeholder(dtypes.int32, shape=[64])
        decoder_inputs = array_ops.placeholder(dtype, shape=[64, None, 128])
        decoder_sequence_length = array_ops.placeholder(dtypes.int32, shape=[64])
        batch_size = 64
        attention_mechanism = wrapper.BahdanauAttention(
            num_units=num_units,
            memory=encoder_outputs,
            memory_sequence_length=encoder_sequence_length,
            normalize=True,
            dtype=dtype,
        )
        cell = rnn_cell.LSTMCell(num_units)
        cell = wrapper.AttentionWrapper(cell, attention_mechanism)

        helper = helper_py.TrainingHelper(decoder_inputs,
                                          decoder_sequence_length)
        my_decoder = basic_decoder.BasicDecoder(
            cell=cell,
            helper=helper,
            initial_state=cell.zero_state(dtype=dtype, batch_size=batch_size))

        final_outputs, final_state, _ = decoder.dynamic_decode(my_decoder)
        self.assertTrue(
            isinstance(final_outputs, basic_decoder.BasicDecoderOutput))
        self.assertEqual(final_outputs.rnn_output.dtype, dtype)
        self.assertTrue(
            isinstance(final_state, wrapper.AttentionWrapperState))
        self.assertTrue(
            isinstance(final_state.cell_state, rnn_cell.LSTMStateTuple))
def build_dec_cell(self, hidden_size):
    enc_outputs = self.enc_outputs
    enc_last_state = self.enc_last_state
    enc_inputs_length = self.enc_inp_len

    if self.use_beam_search:
        self.logger.info("using beam search decoding")
        enc_outputs = seq2seq.tile_batch(
            self.enc_outputs, multiplier=self.p.beam_width)
        enc_last_state = nest.map_structure(
            lambda s: seq2seq.tile_batch(s, self.p.beam_width),
            self.enc_last_state)
        enc_inputs_length = seq2seq.tile_batch(
            self.enc_inp_len, self.p.beam_width)

    if self.p.attention_type.lower() == 'luong':
        self.attention_mechanism = attention_wrapper.LuongAttention(
            num_units=hidden_size,
            memory=enc_outputs,
            memory_sequence_length=enc_inputs_length)
    else:
        self.attention_mechanism = attention_wrapper.BahdanauAttention(
            num_units=hidden_size,
            memory=enc_outputs,
            memory_sequence_length=enc_inputs_length)

    def attn_dec_input_fn(inputs, attention):
        if not self.p.attn_input_feeding:
            return inputs
        else:
            _input_layer = Dense(hidden_size, dtype=self.p.dtype,
                                 name='attn_input_feeding')
            return _input_layer(tf.concat([inputs, attention], -1))

    self.dec_cell_list = [
        self.build_single_cell(hidden_size) for _ in range(self.p.depth)
    ]

    if self.p.use_attn:
        self.dec_cell_list[-1] = attention_wrapper.AttentionWrapper(
            cell=self.dec_cell_list[-1],
            attention_mechanism=self.attention_mechanism,
            attention_layer_size=hidden_size,
            cell_input_fn=attn_dec_input_fn,
            initial_cell_state=enc_last_state[-1],
            alignment_history=False,
            name='attention_wrapper')

    batch_size = self.p.batch_size if not self.use_beam_search \
        else self.p.batch_size * self.p.beam_width
    initial_state = [state for state in enc_last_state]
    if self.p.use_attn:
        initial_state[-1] = self.dec_cell_list[-1].zero_state(
            batch_size=batch_size, dtype=self.p.dtype)
    dec_initial_state = tuple(initial_state)

    return MultiRNNCell(self.dec_cell_list), dec_initial_state
def create_decoder(self, encoded, inputs, speaker_embed, train=True):
    config = self.config
    attention_mech = wrapper.BahdanauAttention(
        config.attention_units, encoded,
        memory_sequence_length=inputs['text_length'])

    inner_cell = [GRUCell(config.decoder_units) for _ in range(3)]
    decoder_cell = OutputProjectionWrapper(
        InputProjectionWrapper(
            ResidualWrapper(MultiRNNCell(inner_cell)), config.decoder_units),
        config.mel_features * config.r)

    # Feed in the r-th frame at each time step: slice out the last of the
    # r stacked mel frames, run it through the pre-net, and concatenate
    # the attention context.
    def decoder_frame_input(inputs, attention):
        frame = tf.slice(inputs,
                         [0, (config.r - 1) * config.mel_features], [-1, -1])
        processed = self.pre_net(frame,
                                 dropout=config.audio_dropout_prob,
                                 train=train)
        return tf.concat([processed, attention], -1)

    cell = wrapper.AttentionWrapper(
        decoder_cell,
        attention_mech,
        attention_layer_size=config.attention_units,
        cell_input_fn=decoder_frame_input,
        alignment_history=True,
        output_attention=False)

    if train:
        if config.scheduled_sample:
            print("scheduled sampling: %s" % str(
                (inputs['mel'], inputs['speech_length'],
                 config.scheduled_sample)))
            decoder_helper = helper.ScheduledOutputTrainingHelper(
                inputs['mel'], inputs['speech_length'],
                config.scheduled_sample)
        else:
            decoder_helper = helper.TrainingHelper(
                inputs['mel'], inputs['speech_length'])
    else:
        decoder_helper = ops.InferenceHelper(
            tf.shape(inputs['text'])[0], config.mel_features * config.r)

    initial_state = cell.zero_state(
        dtype=tf.float32, batch_size=tf.shape(inputs['text'])[0])
    # if speaker_embed is not None:
    #     initial_state.attention = tf.layers.dense(
    #         speaker_embed, config.attention_units)
    dec = basic_decoder.BasicDecoder(cell, decoder_helper, initial_state)
    return dec
def build_decoder_cell(self):
    encoder_outputs = self.encoder_outputs
    encoder_last_state = self.encoder_last_state
    encoder_inputs_length = self.encoder_inputs_length

    # Building attention mechanism: default Bahdanau
    # 'Bahdanau': https://arxiv.org/abs/1409.0473
    self.attention_mechanism = attention_wrapper.BahdanauAttention(
        num_units=self.hidden_size,
        memory=encoder_outputs,
        memory_sequence_length=encoder_inputs_length)
    # 'Luong': https://arxiv.org/abs/1508.04025
    if self.attention_type.lower() == 'luong':
        self.attention_mechanism = attention_wrapper.LuongAttention(
            num_units=self.hidden_size,
            memory=self.encoder_outputs,
            memory_sequence_length=self.encoder_inputs_length)

    # Building decoder_cell
    self.decoder_cell_list = [
        self.build_single_cell() for _ in range(self.layer_num)
    ]

    def att_decoder_input_fn(inputs, attention):
        if not self.use_att_decoding:
            return inputs
        _input_layer = Dense(self.hidden_size, dtype=self.dtype,
                             name='att_input_feeding')
        return _input_layer(array_ops.concat([inputs, attention], axis=-1))

    # AttentionWrapper wraps the RNNCell with the attention_mechanism;
    # the attention mechanism is applied only on the top decoder layer.
    self.decoder_cell_list[-1] = attention_wrapper.AttentionWrapper(
        cell=self.decoder_cell_list[-1],
        attention_mechanism=self.attention_mechanism,
        attention_layer_size=self.hidden_size,
        cell_input_fn=att_decoder_input_fn,
        # last hidden state of the last encoder layer
        initial_cell_state=encoder_last_state[-1],
        alignment_history=False,
        name='Attention_Wrapper')

    initial_state = [state for state in encoder_last_state]
    initial_state[-1] = self.decoder_cell_list[-1].zero_state(
        batch_size=self.batch_size, dtype=self.dtype)
    decoder_initial_state = tuple(initial_state)

    return MultiRNNCell(self.decoder_cell_list), decoder_initial_state
def build_decoder_cell(self, num_units, num_layers, keep_prob):
    encoder_outputs = tf.concat(self.encoder_outputs, axis=-1)

    # Merge the forward and backward final states of the bidirectional
    # encoder into one LSTMStateTuple per layer.
    encoder_final_state = []
    encoder_fw_fs, encoder_bw_fs = self.encoder_fs
    for i in range(num_layers):
        final_state_c = tf.concat((encoder_fw_fs[i].c, encoder_bw_fs[i].c),
                                  axis=1)
        final_state_h = tf.concat((encoder_fw_fs[i].h, encoder_bw_fs[i].h),
                                  axis=1)
        encoder_final_state.append(
            LSTMStateTuple(c=final_state_c, h=final_state_h))
    encoder_fs = tuple(encoder_final_state)

    # build decoder cell
    decoder_cells = [
        self.make_rnn_cell(num_units, keep_prob) for _ in range(num_layers)
    ]
    attention_cell = decoder_cells.pop()

    # Apply Bahdanau attention to the top decoder layer only.
    self.attention_mechanism = attention_wrapper.BahdanauAttention(
        num_units=num_units,
        memory=encoder_outputs,
        normalize=False,
        memory_sequence_length=self.encoder_length)

    attention_cell = attention_wrapper.AttentionWrapper(
        attention_cell,
        self.attention_mechanism,
        attention_layer_size=None,
        initial_cell_state=None,
        output_attention=False,
        alignment_history=False,
    )
    decoder_cells.append(attention_cell)
    decoder_cells = tf.nn.rnn_cell.MultiRNNCell(decoder_cells)

    batch = self.batch
    # The attention-wrapped top layer's zero state is an
    # AttentionWrapperState, so it cannot take the encoder's LSTMStateTuple
    # directly; clone it with cell_state=encoder state instead. The plain
    # layers use the encoder states as-is.
    decoder_init_state = tuple(
        zs.clone(cell_state=es) if isinstance(
            zs, tf.contrib.seq2seq.AttentionWrapperState) else es
        for zs, es in zip(
            decoder_cells.zero_state(batch, dtype=tf.float32), encoder_fs))
    return decoder_cells, decoder_init_state
def build_decoder_cell(self):
    encoder_outputs = self.encoder_outputs
    encoder_last_state = self.encoder_last_state
    encoder_inputs_length = self.encoder_inputs_length

    # Building attention mechanism: default Bahdanau
    self.attention_mechanism = attention_wrapper.BahdanauAttention(
        num_units=self.decoder_hidden_units,
        memory=encoder_outputs,
        memory_sequence_length=encoder_inputs_length)

    self.decoder_cell_list = [
        self.build_single_cell() for _ in range(self.depth)
    ]

    def attn_decoder_input_fn(inputs, attention):
        if not self.attn_input_feeding:
            return inputs
        _input_layer = Dense(self.decoder_hidden_units, dtype=tf.float32,
                             name='attn_input_feeding')
        return _input_layer(array_ops.concat([inputs, attention], -1))

    # AttentionWrapper wraps the RNNCell with the attention_mechanism
    self.decoder_cell_list[-1] = attention_wrapper.AttentionWrapper(
        cell=self.decoder_cell_list[-1],
        attention_mechanism=self.attention_mechanism,
        attention_layer_size=self.decoder_hidden_units,
        cell_input_fn=attn_decoder_input_fn,
        initial_cell_state=encoder_last_state[-1],
        alignment_history=False,
        name='Attention_Wrapper')

    batch_size = self.batch_size
    initial_state = [state for state in encoder_last_state]
    initial_state[-1] = self.decoder_cell_list[-1].zero_state(
        batch_size=batch_size, dtype=tf.float32)
    decoder_initial_state = tuple(initial_state)

    return tf.contrib.rnn.MultiRNNCell(
        self.decoder_cell_list), decoder_initial_state
def build_decoder_cell(self):
    self.decoder_cell_list = [
        self.build_single_cell() for _ in range(self.para.num_layers)]

    if self.para.mode == 'train':
        encoder_outputs = self.encoder_outputs
        encoder_inputs_len = self.encoder_inputs_len
        encoder_states = self.encoder_states
        batch_size = self.para.batch_size
    else:
        # For beam search decoding, tile the encoder tensors beam_width times.
        encoder_outputs = seq2seq.tile_batch(
            self.encoder_outputs, multiplier=self.para.beam_width)
        encoder_inputs_len = seq2seq.tile_batch(
            self.encoder_inputs_len, multiplier=self.para.beam_width)
        encoder_states = seq2seq.tile_batch(
            self.encoder_states, multiplier=self.para.beam_width)
        batch_size = self.para.batch_size * self.para.beam_width

    if self.para.attention_mode == 'luong':
        # scaled Luong: recommended by the authors of the NMT tutorial
        self.attention_mechanism = attention_wrapper.LuongAttention(
            num_units=self.para.num_units,
            memory=encoder_outputs,
            memory_sequence_length=encoder_inputs_len,
            scale=True)
        output_attention = True
    else:
        self.attention_mechanism = attention_wrapper.BahdanauAttention(
            num_units=self.para.num_units,
            memory=encoder_outputs,
            memory_sequence_length=encoder_inputs_len)
        output_attention = False

    cell = tf.contrib.rnn.MultiRNNCell(self.decoder_cell_list)
    cell = attention_wrapper.AttentionWrapper(
        cell=cell,
        attention_mechanism=self.attention_mechanism,
        attention_layer_size=self.para.num_units,
        # emit the attention context (Luong) or the cell output (Bahdanau)
        output_attention=output_attention,
        name='attention')
    decoder_initial_state = cell.zero_state(
        batch_size, self.dtype).clone(cell_state=encoder_states)

    return cell, decoder_initial_state
def build_attention_decoder_cell(self):
    encoder_outputs = self.encoder_outputs
    encoder_last_state = self.encoder_last_state
    encoder_inputs_length = self.encoder_inputs_length

    self.attention_mechanism = attention_wrapper.BahdanauAttention(
        num_units=self.hidden_units,
        memory=encoder_outputs,
        memory_sequence_length=encoder_inputs_length,
    )
    if self.attention_type.lower() == 'luong':
        self.attention_mechanism = attention_wrapper.LuongAttention(
            num_units=self.hidden_units,
            memory=encoder_outputs,
            memory_sequence_length=encoder_inputs_length,
        )

    # Building decoder_cell
    self.decoder_cell_list = [
        self.build_single_cell() for _ in range(self.num_layers)]

    def attn_decoder_input_fn(inputs, attention):
        if not self.attn_input_feeding:
            return inputs
        # Essential when use_residual=True: project the concatenated
        # [inputs; attention] back down to hidden_units.
        projected = tf.layers.dense(
            tf.concat([inputs, attention], axis=-1),
            self.hidden_units, name='attn_input_feeding')
        return projected

    self.decoder_cell_list[-1] = attention_wrapper.AttentionWrapper(
        cell=self.decoder_cell_list[-1],
        attention_mechanism=self.attention_mechanism,
        attention_layer_size=self.hidden_units,
        cell_input_fn=attn_decoder_input_fn,
        initial_cell_state=encoder_last_state[-1],
        alignment_history=False,
        name='Attention_Wrapper'
    )

    batch_size = self.config['batch_size']
    initial_state = [state for state in encoder_last_state]
    initial_state[-1] = self.decoder_cell_list[-1].zero_state(
        batch_size=batch_size, dtype=self.dtype)
    decoder_initial_state = tuple(initial_state)

    return MultiRNNCell(self.decoder_cell_list), decoder_initial_state
def _testDynamicDecodeRNN(self, time_major, has_attention):
    encoder_sequence_length = np.array([3, 2, 3, 1, 1])
    decoder_sequence_length = np.array([2, 0, 1, 2, 3])
    batch_size = 5
    decoder_max_time = 4
    input_depth = 7
    cell_depth = 9
    attention_depth = 6
    vocab_size = 20
    end_token = vocab_size - 1
    start_token = 0
    embedding_dim = 50
    max_out = max(decoder_sequence_length)
    output_layer = layers_core.Dense(vocab_size, use_bias=True,
                                     activation=None)
    beam_width = 3

    with self.test_session() as sess:
        batch_size_tensor = constant_op.constant(batch_size)
        embedding = np.random.randn(vocab_size,
                                    embedding_dim).astype(np.float32)
        cell = rnn_cell.LSTMCell(cell_depth)
        initial_state = cell.zero_state(batch_size, dtypes.float32)
        if has_attention:
            inputs = array_ops.placeholder_with_default(
                np.random.randn(batch_size, decoder_max_time,
                                input_depth).astype(np.float32),
                shape=(None, None, input_depth))
            tiled_inputs = beam_search_decoder.tile_batch(
                inputs, multiplier=beam_width)
            tiled_sequence_length = beam_search_decoder.tile_batch(
                encoder_sequence_length, multiplier=beam_width)
            attention_mechanism = attention_wrapper.BahdanauAttention(
                num_units=attention_depth,
                memory=tiled_inputs,
                memory_sequence_length=tiled_sequence_length)
            initial_state = beam_search_decoder.tile_batch(
                initial_state, multiplier=beam_width)
            cell = attention_wrapper.AttentionWrapper(
                cell=cell,
                attention_mechanism=attention_mechanism,
                attention_layer_size=attention_depth,
                alignment_history=False)
        cell_state = cell.zero_state(
            dtype=dtypes.float32,
            batch_size=batch_size_tensor * beam_width)
        if has_attention:
            cell_state = cell_state.clone(cell_state=initial_state)
        bsd = beam_search_decoder.BeamSearchDecoder(
            cell=cell,
            embedding=embedding,
            start_tokens=array_ops.fill([batch_size_tensor], start_token),
            end_token=end_token,
            initial_state=cell_state,
            beam_width=beam_width,
            output_layer=output_layer,
            length_penalty_weight=0.0)
        final_outputs, final_state, final_sequence_lengths = (
            decoder.dynamic_decode(
                bsd, output_time_major=time_major,
                maximum_iterations=max_out))

        def _t(shape):
            if time_major:
                return (shape[1], shape[0]) + shape[2:]
            return shape

        self.assertTrue(
            isinstance(final_outputs,
                       beam_search_decoder.FinalBeamSearchDecoderOutput))
        self.assertTrue(
            isinstance(final_state,
                       beam_search_decoder.BeamSearchDecoderState))

        beam_search_decoder_output = final_outputs.beam_search_decoder_output
        self.assertEqual(
            _t((batch_size, None, beam_width)),
            tuple(beam_search_decoder_output.scores.get_shape().as_list()))
        self.assertEqual(
            _t((batch_size, None, beam_width)),
            tuple(final_outputs.predicted_ids.get_shape().as_list()))

        sess.run(variables.global_variables_initializer())
        sess_results = sess.run({
            'final_outputs': final_outputs,
            'final_state': final_state,
            'final_sequence_lengths': final_sequence_lengths
        })

        max_sequence_length = np.max(sess_results['final_sequence_lengths'])

        # A smoke test
        self.assertEqual(
            _t((batch_size, max_sequence_length, beam_width)),
            sess_results['final_outputs'].beam_search_decoder_output.
            scores.shape)
        self.assertEqual(
            _t((batch_size, max_sequence_length, beam_width)),
            sess_results['final_outputs'].beam_search_decoder_output.
            predicted_ids.shape)
def build_decoder_cell(self):
    encoder_outputs = self.encoder_outputs
    encoder_last_state = self.encoder_last_state
    encoder_inputs_length = self.encoder_inputs_length

    # To use BeamSearchDecoder, encoder_outputs, encoder_last_state and
    # encoder_inputs_length need to be tiled so that:
    # [batch_size, .., ..] -> [batch_size x beam_width, .., ..]
    if self.use_beamsearch_decode:
        print("use beamsearch decoding..")
        encoder_outputs = seq2seq.tile_batch(
            self.encoder_outputs, multiplier=self.beam_width)
        encoder_last_state = nest.map_structure(
            lambda s: seq2seq.tile_batch(s, self.beam_width),
            self.encoder_last_state)
        encoder_inputs_length = seq2seq.tile_batch(
            self.encoder_inputs_length, multiplier=self.beam_width)

    # Building attention mechanism: default Bahdanau
    # 'Bahdanau' style attention: https://arxiv.org/abs/1409.0473
    self.attention_mechanism = attention_wrapper.BahdanauAttention(
        num_units=self.hidden_units,
        memory=encoder_outputs,
        memory_sequence_length=encoder_inputs_length,
    )
    # 'Luong' style attention: https://arxiv.org/abs/1508.04025
    if self.attention_type.lower() == 'luong':
        self.attention_mechanism = attention_wrapper.LuongAttention(
            num_units=self.hidden_units,
            memory=encoder_outputs,
            memory_sequence_length=encoder_inputs_length,
        )

    # Building decoder_cell
    self.decoder_cell_list = [
        self.build_single_cell() for _ in range(self.depth)
    ]

    def attn_decoder_input_fn(inputs, attention):
        if not self.attn_input_feeding:
            return inputs
        # Essential when use_residual=True
        _input_layer = Dense(self.hidden_units, dtype=self.dtype,
                             name='attn_input_feeding')
        return _input_layer(array_ops.concat([inputs, attention], -1))

    # AttentionWrapper wraps the RNNCell with the attention_mechanism.
    # Note: the attention mechanism is applied only on the top decoder layer.
    self.decoder_cell_list[-1] = attention_wrapper.AttentionWrapper(
        cell=self.decoder_cell_list[-1],
        attention_mechanism=self.attention_mechanism,
        attention_layer_size=self.hidden_units,
        cell_input_fn=attn_decoder_input_fn,
        initial_cell_state=encoder_last_state[-1],
        alignment_history=True,
        name='Attention_Wrapper')

    # To be compatible with AttentionWrapper, the encoder last state of the
    # top layer should be converted into the AttentionWrapperState form.
    # We can easily do this by calling AttentionWrapper.zero_state.
    # Also, if beam search decoding is used, the batch_size argument in
    # .zero_state should be beam_width times the original batch_size.
    batch_size = self.batch_size if not self.use_beamsearch_decode \
        else self.batch_size * self.beam_width
    initial_state = [state for state in encoder_last_state]
    initial_state[-1] = self.decoder_cell_list[-1].zero_state(
        batch_size=batch_size, dtype=self.dtype)
    decoder_initial_state = tuple(initial_state)

    return MultiRNNCell(self.decoder_cell_list), decoder_initial_state
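# A hedged sketch of how the (cell, initial_state) pair returned by
# build_decoder_cell above is commonly consumed in the tf.contrib era:
# TrainingHelper + BasicDecoder driven by dynamic_decode. The attributes
# `decoder_inputs_embedded` and `decoder_inputs_length` are assumptions for
# illustration, not part of the original model code.
def build_train_decoder_example(self):
    decoder_cell, decoder_initial_state = self.build_decoder_cell()
    training_helper = seq2seq.TrainingHelper(
        inputs=self.decoder_inputs_embedded,         # [batch, time, emb]
        sequence_length=self.decoder_inputs_length)  # [batch]
    training_decoder = seq2seq.BasicDecoder(
        cell=decoder_cell,
        helper=training_helper,
        initial_state=decoder_initial_state)
    # dynamic_decode unrolls the decoder until the helper signals completion
    outputs, final_state, final_lengths = seq2seq.dynamic_decode(
        training_decoder,
        maximum_iterations=tf.reduce_max(self.decoder_inputs_length))
    return outputs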
def _attention_decoder_wrapper(batch_size, num_units, memory, multi_layer,
                               dtype=dtypes.float32, attention_layer_size=None,
                               cell_input_fn=None, attention_type='B',
                               probability_fn=None, alignment_history=False,
                               output_attention=True, initial_cell_state=None,
                               normalization=False, sigmoid_noise=0.,
                               sigmoid_noise_seed=None, score_bias_init=0.):
    """
    A wrapper for an RNN decoder with an attention mechanism.
    A detailed explanation of the parameters can be found at:
    blog.csdn.net/qsczse943062710/article/details/79539005

    :param multi_layer: an object returned by the function _mutli_layer_rnn()
    :param attention_type: string
        'B' is for BahdanauAttention as described in:
            Dzmitry Bahdanau, Kyunghyun Cho, Yoshua Bengio.
            "Neural Machine Translation by Jointly Learning to Align and
            Translate." ICLR 2015. https://arxiv.org/abs/1409.0473
        'L' is for LuongAttention as described in:
            Minh-Thang Luong, Hieu Pham, Christopher D. Manning.
            "Effective Approaches to Attention-based Neural Machine
            Translation." EMNLP 2015. https://arxiv.org/abs/1508.04025
        MonotonicAttention is described in:
            Colin Raffel, Minh-Thang Luong, Peter J. Liu, Ron J. Weiss,
            Douglas Eck. "Online and Linear-Time Attention by Enforcing
            Monotonic Alignments." ICML 2017. https://arxiv.org/abs/1704.00784
        'BM': monotonic attention mechanism with Bahdanau-style energy function
        'LM': monotonic attention mechanism with Luong-style energy function
        or maybe something user-defined in the future.

    **Warning**: if normalization is set to True, normalization will be
    applied to all types of attention, as described in:
        Tim Salimans, Diederik P. Kingma. "Weight Normalization: A Simple
        Reparameterization to Accelerate Training of Deep Neural Networks."
        https://arxiv.org/abs/1602.07868

    An example usage:

        att_wrapper, states = _attention_decoder_wrapper(*args)
        while decoding:
            output, states = att_wrapper(input, states)
            ... some processing on output ...
            input = processed_output
    """
    if attention_type == 'B':
        attention_mechanism = att_w.BahdanauAttention(
            num_units=num_units, memory=memory,
            probability_fn=probability_fn, normalize=normalization)
    elif attention_type == 'BM':
        attention_mechanism = att_w.BahdanauMonotonicAttention(
            num_units=num_units, memory=memory, normalize=normalization,
            sigmoid_noise=sigmoid_noise,
            sigmoid_noise_seed=sigmoid_noise_seed,
            score_bias_init=score_bias_init)
    elif attention_type == 'L':
        attention_mechanism = att_w.LuongAttention(
            num_units=num_units, memory=memory,
            probability_fn=probability_fn, scale=normalization)
    elif attention_type == 'LM':
        attention_mechanism = att_w.LuongMonotonicAttention(
            num_units=num_units, memory=memory, scale=normalization,
            sigmoid_noise=sigmoid_noise,
            sigmoid_noise_seed=sigmoid_noise_seed,
            score_bias_init=score_bias_init)
    else:
        raise ValueError('Invalid attention type')

    att_wrapper = att_w.AttentionWrapper(
        cell=multi_layer,
        attention_mechanism=attention_mechanism,
        attention_layer_size=attention_layer_size,
        cell_input_fn=cell_input_fn,
        alignment_history=alignment_history,
        output_attention=output_attention,
        initial_cell_state=initial_cell_state)
    init_states = att_wrapper.zero_state(batch_size=batch_size, dtype=dtype)
    return att_wrapper, init_states
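# An illustrative, self-contained call of the wrapper above (assumed sizes;
# a 2-layer LSTM stack stands in for whatever _mutli_layer_rnn() returns).
def _attention_decoder_wrapper_example():
    import tensorflow as tf
    from tensorflow.python.ops import rnn_cell
    memory = tf.random_normal([8, 10, 32])   # [batch, enc_time, units]
    cells = rnn_cell.MultiRNNCell(
        [rnn_cell.LSTMCell(32) for _ in range(2)])
    att_wrapper, states = _attention_decoder_wrapper(
        batch_size=8, num_units=32, memory=memory, multi_layer=cells,
        attention_type='BM', normalization=True)
    # one decode step: feed an input frame, get the attended output
    output, states = att_wrapper(tf.random_normal([8, 32]), states)
    return output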
import tensorflow as tf

from tensorflow.contrib.seq2seq.python.ops import basic_decoder
from tensorflow.contrib.seq2seq.python.ops import decoder
from tensorflow.contrib.seq2seq.python.ops import helper as helper_py
from tensorflow.contrib.seq2seq.python.ops import attention_wrapper
from tensorflow.python.ops import rnn_cell

# tf.enable_eager_execution()

batch_size = 5
src_len = [4, 5, 3, 5, 6]
max_times = 6
num_units = 16
enc_output = tf.random.normal((batch_size, max_times, num_units),
                              dtype=tf.float32)

# attention RNN cell
rnncell = rnn_cell.LSTMCell(num_units=16)
attention_mechanism = attention_wrapper.BahdanauAttention(
    num_units=num_units, memory=enc_output, memory_sequence_length=src_len)
attnRNNCell = attention_wrapper.AttentionWrapper(
    cell=rnncell,
    attention_mechanism=attention_mechanism,
    alignment_history=True)

# training
tgt_len = [5, 6, 2, 7, 4]
tgt_max_times = 7
tgt_inputs = tf.random.normal((batch_size, tgt_max_times, num_units),
                              dtype=tf.float32)
training_helper = helper_py.TrainingHelper(tgt_inputs, tgt_len)  # train helper
train_decoder = basic_decoder.BasicDecoder(
    cell=attnRNNCell,
    helper=training_helper,
    initial_state=attnRNNCell.zero_state(batch_size, dtype=tf.float32))
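# A minimal sketch of actually running the training decoder built above,
# using the `decoder` module imported at the top (graph mode assumed, so the
# eager line above must stay commented out). With the sample sizes here,
# rnn_output should come back as [batch_size, tgt_max_times, num_units].
outputs, final_state, final_seq_len = decoder.dynamic_decode(
    train_decoder, maximum_iterations=tgt_max_times)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    rnn_out = sess.run(outputs.rnn_output)
    print(rnn_out.shape)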