def _build_decoder_cell(self):
    # no beam
    encoder_outputs = self.encoder_outputs
    encoder_last_state = self.encoder_last_state
    encoder_inputs_length = self.encoder_inputs_length

    def attn_decoder_input_fn(inputs, attention):
        if not self.attn_input_feeding:
            return inputs
        _input_layer = Dense(self.hidden_units, dtype=self.dtype,
                             name="attn_input_feeding")
        return _input_layer(array_ops.concat([inputs, attention], -1))

    # attention mechanism 'luong'
    with tf.variable_scope('shared_attention_mechanism'):
        self.attention_mechanism = attention_wrapper.LuongAttention(
            num_units=self.hidden_units,
            memory=encoder_outputs,
            memory_sequence_length=encoder_inputs_length)

    # build decoder cell
    self.init_decoder_cell_list = [
        self._build_single_cell() for i in range(self.depth)]
    decoder_initial_state = encoder_last_state
    self.decoder_cell_list = self.init_decoder_cell_list[:-1] + [
        attention_wrapper.AttentionWrapper(
            cell=self.init_decoder_cell_list[-1],
            attention_mechanism=self.attention_mechanism,
            attention_layer_size=self.hidden_units,
            cell_input_fn=attn_decoder_input_fn,
            initial_cell_state=encoder_last_state[-1],
            alignment_history=False)]
    batch_size = self.batch_size
    initial_state = [state for state in encoder_last_state]
    initial_state[-1] = self.decoder_cell_list[-1].zero_state(
        batch_size=batch_size, dtype=self.dtype)
    decoder_initial_state = tuple(initial_state)

    # beam
    beam_encoder_outputs = seq2seq.tile_batch(
        self.encoder_outputs, multiplier=self.beam_width)
    beam_encoder_last_state = nest.map_structure(
        lambda s: seq2seq.tile_batch(s, self.beam_width),
        self.encoder_last_state)
    beam_encoder_inputs_length = seq2seq.tile_batch(
        self.encoder_inputs_length, multiplier=self.beam_width)

    with tf.variable_scope('shared_attention_mechanism', reuse=True):
        self.beam_attention_mechanism = attention_wrapper.LuongAttention(
            num_units=self.hidden_units,
            memory=beam_encoder_outputs,
            memory_sequence_length=beam_encoder_inputs_length)

    beam_decoder_initial_state = beam_encoder_last_state
    self.beam_decoder_cell_list = self.init_decoder_cell_list[:-1] + [
        attention_wrapper.AttentionWrapper(
            cell=self.init_decoder_cell_list[-1],
            attention_mechanism=self.beam_attention_mechanism,
            attention_layer_size=self.hidden_units,
            cell_input_fn=attn_decoder_input_fn,
            initial_cell_state=beam_encoder_last_state[-1],
            alignment_history=False)]
    beam_batch_size = self.batch_size * self.beam_width
    beam_initial_state = [state for state in beam_encoder_last_state]
    beam_initial_state[-1] = self.beam_decoder_cell_list[-1].zero_state(
        batch_size=beam_batch_size, dtype=self.dtype)
    beam_decoder_initial_state = tuple(beam_initial_state)

    return MultiRNNCell(self.decoder_cell_list), decoder_initial_state, \
        MultiRNNCell(self.beam_decoder_cell_list), beam_decoder_initial_state
def testLuongScaledDType(self):
    # Test case for GitHub issue 18099
    for dtype in [np.float16, np.float32, np.float64]:
        num_units = 128
        encoder_outputs = array_ops.placeholder(dtype, shape=[64, None, 256])
        encoder_sequence_length = array_ops.placeholder(dtypes.int32, shape=[64])
        decoder_inputs = array_ops.placeholder(dtype, shape=[64, None, 128])
        decoder_sequence_length = array_ops.placeholder(dtypes.int32, shape=[64])
        batch_size = 64
        attention_mechanism = wrapper.LuongAttention(
            num_units=num_units,
            memory=encoder_outputs,
            memory_sequence_length=encoder_sequence_length,
            scale=True,
            dtype=dtype,
        )
        cell = rnn_cell.LSTMCell(num_units)
        cell = wrapper.AttentionWrapper(cell, attention_mechanism)

        helper = helper_py.TrainingHelper(decoder_inputs,
                                          decoder_sequence_length)
        my_decoder = basic_decoder.BasicDecoder(
            cell=cell,
            helper=helper,
            initial_state=cell.zero_state(dtype=dtype, batch_size=batch_size))

        final_outputs, final_state, _ = decoder.dynamic_decode(my_decoder)
        self.assertTrue(
            isinstance(final_outputs, basic_decoder.BasicDecoderOutput))
        self.assertEqual(final_outputs.rnn_output.dtype, dtype)
        self.assertTrue(
            isinstance(final_state, wrapper.AttentionWrapperState))
        self.assertTrue(
            isinstance(final_state.cell_state, rnn_cell.LSTMStateTuple))
def testCustomizedAttention(self):
    batch_size = 2
    max_time = 3
    num_units = 2
    memory = constant_op.constant([[[1., 1.], [2., 2.], [3., 3.]],
                                   [[4., 4.], [5., 5.], [6., 6.]]])
    memory_sequence_length = constant_op.constant([3, 2])
    attention_mechanism = wrapper.BahdanauAttention(num_units, memory,
                                                    memory_sequence_length)

    # Sets all returned values to be all ones.
    def _customized_attention(unused_attention_mechanism, unused_cell_output,
                              unused_attention_state, unused_attention_layer):
        """Customized attention.

        Returns:
          attention: `Tensor` of shape [batch_size, num_units], attention output.
          alignments: `Tensor` of shape [batch_size, max_time], sigma value for
            each input memory (prob. function of input keys).
          next_attention_state: A `Tensor` representing the next state for the
            attention.
        """
        attention = array_ops.ones([batch_size, num_units])
        alignments = array_ops.ones([batch_size, max_time])
        next_attention_state = alignments
        return attention, alignments, next_attention_state

    attention_cell = wrapper.AttentionWrapper(
        rnn_cell.LSTMCell(2),
        attention_mechanism,
        attention_layer_size=None,  # don't use attention layer.
        output_attention=False,
        alignment_history=(),
        attention_fn=_customized_attention,
        name='attention')
    self.assertEqual(num_units, attention_cell.output_size)

    initial_state = attention_cell.zero_state(
        batch_size=2, dtype=dtypes.float32)
    source_input_emb = array_ops.ones([2, 3, 2])
    source_input_length = constant_op.constant([3, 2])

    # 'state' is a tuple of
    # (cell_state, h, attention, alignments, alignment_history, attention_state)
    output, state = rnn.dynamic_rnn(
        attention_cell,
        inputs=source_input_emb,
        sequence_length=source_input_length,
        initial_state=initial_state,
        dtype=dtypes.float32)

    with self.session() as sess:
        sess.run(variables.global_variables_initializer())
        output_value, state_value = sess.run([output, state], feed_dict={})
        self.assertAllEqual(np.array([2, 3, 2]), output_value.shape)
        self.assertAllClose(np.array([[1., 1.], [1., 1.]]),
                            state_value.attention)
        self.assertAllClose(np.array([[1., 1., 1.], [1., 1., 1.]]),
                            state_value.alignments)
        self.assertAllClose(np.array([[1., 1., 1.], [1., 1., 1.]]),
                            state_value.attention_state)
def build_decoder_cell(self):
    encoder_outputs = self.encoder_outputs
    encoder_last_state = self.encoder_last_state
    encoder_inputs_length = self.encoder_inputs_length

    if self.use_beamsearch_decode:
        print("use beamsearch decoding..")
        encoder_outputs = seq2seq.tile_batch(
            self.encoder_outputs, multiplier=self.beam_width)
        encoder_last_state = nest.map_structure(
            lambda s: seq2seq.tile_batch(s, self.beam_width),
            self.encoder_last_state)
        encoder_inputs_length = seq2seq.tile_batch(
            self.encoder_inputs_length, multiplier=self.beam_width)

    # Building attention mechanism: Default Bahdanau
    # 'Bahdanau' style attention: https://arxiv.org/abs/1409.0473
    self.attention_mechanism = attention_wrapper.BahdanauAttention(
        num_units=self.hidden_units,
        memory=encoder_outputs,
        memory_sequence_length=encoder_inputs_length,)
    # 'Luong' style attention: https://arxiv.org/abs/1508.04025
    if self.attention_type.lower() == 'luong':
        self.attention_mechanism = attention_wrapper.LuongAttention(
            num_units=self.hidden_units,
            memory=encoder_outputs,
            memory_sequence_length=encoder_inputs_length,)

    # Building decoder_cell
    self.decoder_cell_list = [
        self.build_single_cell() for i in range(self.depth)]
    decoder_initial_state = encoder_last_state

    def attn_decoder_input_fn(inputs, attention):
        if not self.attn_input_feeding:
            return inputs
        # Essential when use_residual=True
        _input_layer = Dense(self.hidden_units, dtype=self.dtype,
                             name='attn_input_feeding')
        return _input_layer(array_ops.concat([inputs, attention], -1))

    # AttentionWrapper wraps RNNCell with the attention_mechanism
    # Note: We implement Attention mechanism only on the top decoder layer
    self.decoder_cell_list[-1] = attention_wrapper.AttentionWrapper(
        cell=self.decoder_cell_list[-1],
        attention_mechanism=self.attention_mechanism,
        attention_layer_size=self.hidden_units,
        cell_input_fn=attn_decoder_input_fn,
        initial_cell_state=encoder_last_state[-1],
        alignment_history=False,
        name='Attention_Wrapper')

    batch_size = self.batch_size if not self.use_beamsearch_decode \
        else self.batch_size * self.beam_width
    initial_state = [state for state in encoder_last_state]
    initial_state[-1] = self.decoder_cell_list[-1].zero_state(
        batch_size=batch_size, dtype=self.dtype)
    decoder_initial_state = tuple(initial_state)

    return MultiRNNCell(self.decoder_cell_list), decoder_initial_state
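# A minimal usage sketch (not part of the original snippet): it assumes the
# (cell, initial_state) pair returned by build_decoder_cell above, plus
# hypothetical `decoder_embeddings`, `decoder_inputs`, `decoder_inputs_length`,
# `start_tokens`, `end_token`, `beam_width`, and `output_layer` objects, and the
# TF 1.x tf.contrib.seq2seq API. It shows how such a cell is typically fed into
# BasicDecoder (training) or BeamSearchDecoder (beam-search inference).
import tensorflow as tf
from tensorflow.contrib import seq2seq


def run_decoder_sketch(decoder_cell, decoder_initial_state, use_beamsearch_decode,
                       decoder_embeddings, decoder_inputs, decoder_inputs_length,
                       start_tokens, end_token, beam_width, output_layer):
    if not use_beamsearch_decode:
        # Training: feed ground-truth inputs step by step with a TrainingHelper.
        helper = seq2seq.TrainingHelper(
            inputs=tf.nn.embedding_lookup(decoder_embeddings, decoder_inputs),
            sequence_length=decoder_inputs_length)
        my_decoder = seq2seq.BasicDecoder(
            cell=decoder_cell, helper=helper,
            initial_state=decoder_initial_state, output_layer=output_layer)
    else:
        # Inference: the initial state was already tiled to
        # batch_size * beam_width inside build_decoder_cell.
        my_decoder = seq2seq.BeamSearchDecoder(
            cell=decoder_cell, embedding=decoder_embeddings,
            start_tokens=start_tokens, end_token=end_token,
            initial_state=decoder_initial_state, beam_width=beam_width,
            output_layer=output_layer)
    outputs, final_state, lengths = seq2seq.dynamic_decode(
        my_decoder, maximum_iterations=100)
    return outputs, final_state, lengths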
def testLuongScaledDType(self, dtype):
    # Test case for GitHub issue 18099
    encoder_outputs = self.encoder_outputs.astype(dtype)
    decoder_inputs = self.decoder_inputs.astype(dtype)
    attention_mechanism = wrapper.LuongAttentionV2(
        units=self.units,
        memory=encoder_outputs,
        memory_sequence_length=self.encoder_sequence_length,
        scale=True,
        dtype=dtype,
    )
    cell = keras.layers.LSTMCell(self.units, recurrent_activation="sigmoid")
    cell = wrapper.AttentionWrapper(cell, attention_mechanism)

    sampler = sampler_py.TrainingSampler()
    my_decoder = basic_decoder.BasicDecoderV2(cell=cell, sampler=sampler)

    final_outputs, final_state, _ = my_decoder(
        decoder_inputs,
        initial_state=cell.zero_state(dtype=dtype, batch_size=self.batch),
        sequence_length=self.decoder_sequence_length)
    self.assertIsInstance(final_outputs, basic_decoder.BasicDecoderOutput)
    self.assertEqual(final_outputs.rnn_output.dtype, dtype)
    self.assertIsInstance(final_state, wrapper.AttentionWrapperState)
def build_dec_cell(self, hidden_size):
    enc_outputs = self.enc_outputs
    enc_last_state = self.enc_last_state
    enc_inputs_length = self.enc_inp_len

    if self.use_beam_search:
        self.logger.info("using beam search decoding")
        enc_outputs = seq2seq.tile_batch(
            self.enc_outputs, multiplier=self.p.beam_width)
        enc_last_state = nest.map_structure(
            lambda s: seq2seq.tile_batch(s, self.p.beam_width),
            self.enc_last_state)
        enc_inputs_length = seq2seq.tile_batch(
            self.enc_inp_len, self.p.beam_width)

    if self.p.attention_type.lower() == 'luong':
        self.attention_mechanism = attention_wrapper.LuongAttention(
            num_units=hidden_size,
            memory=enc_outputs,
            memory_sequence_length=enc_inputs_length)
    else:
        self.attention_mechanism = attention_wrapper.BahdanauAttention(
            num_units=hidden_size,
            memory=enc_outputs,
            memory_sequence_length=enc_inputs_length)

    def attn_dec_input_fn(inputs, attention):
        if not self.p.attn_input_feeding:
            return inputs
        else:
            _input_layer = Dense(hidden_size, dtype=self.p.dtype,
                                 name='attn_input_feeding')
            return _input_layer(tf.concat([inputs, attention], -1))

    self.dec_cell_list = [
        self.build_single_cell(hidden_size) for _ in range(self.p.depth)
    ]

    if self.p.use_attn:
        self.dec_cell_list[-1] = attention_wrapper.AttentionWrapper(
            cell=self.dec_cell_list[-1],
            attention_mechanism=self.attention_mechanism,
            attention_layer_size=hidden_size,
            cell_input_fn=attn_dec_input_fn,
            initial_cell_state=enc_last_state[-1],
            alignment_history=False,
            name='attention_wrapper')

    batch_size = self.p.batch_size if not self.use_beam_search \
        else self.p.batch_size * self.p.beam_width
    initial_state = [state for state in enc_last_state]
    if self.p.use_attn:
        initial_state[-1] = self.dec_cell_list[-1].zero_state(
            batch_size=batch_size, dtype=self.p.dtype)
    dec_initial_state = tuple(initial_state)

    return MultiRNNCell(self.dec_cell_list), dec_initial_state
def create_decoder(self, encoded, inputs, speaker_embed, train=True):
    config = self.config
    attention_mech = wrapper.BahdanauAttention(
        config.attention_units,
        encoded,
        memory_sequence_length=inputs['text_length'])

    inner_cell = [GRUCell(config.decoder_units) for _ in range(3)]

    decoder_cell = OutputProjectionWrapper(
        InputProjectionWrapper(
            ResidualWrapper(MultiRNNCell(inner_cell)), config.decoder_units),
        config.mel_features * config.r)

    # feed in rth frame at each time step
    decoder_frame_input = \
        lambda inputs, attention: tf.concat(
            [self.pre_net(
                tf.slice(inputs,
                         [0, (config.r - 1) * config.mel_features], [-1, -1]),
                dropout=config.audio_dropout_prob,
                train=train),
             attention],
            -1)

    cell = wrapper.AttentionWrapper(
        decoder_cell,
        attention_mech,
        attention_layer_size=config.attention_units,
        cell_input_fn=decoder_frame_input,
        alignment_history=True,
        output_attention=False)

    if train:
        if config.scheduled_sample:
            print("if train if config.scheduled_sample: %s" % str(
                (inputs['mel'], inputs['speech_length'], config.scheduled_sample)))
            decoder_helper = helper.ScheduledOutputTrainingHelper(
                inputs['mel'], inputs['speech_length'], config.scheduled_sample)
        else:
            decoder_helper = helper.TrainingHelper(
                inputs['mel'], inputs['speech_length'])
    else:
        decoder_helper = ops.InferenceHelper(
            tf.shape(inputs['text'])[0], config.mel_features * config.r)

    initial_state = cell.zero_state(
        dtype=tf.float32, batch_size=tf.shape(inputs['text'])[0])
    # if speaker_embed is not None:
    #     initial_state.attention = tf.layers.dense(speaker_embed, config.attention_units)

    dec = basic_decoder.BasicDecoder(cell, decoder_helper, initial_state)

    return dec
def _attentive_bidirectional_cudnn_LSTM(self,
                                        inputs,
                                        input_size,
                                        lengths,
                                        attention_mechanism=False,
                                        num_units=256,
                                        num_layers=1,
                                        dp_input_keep_prob=1.0,
                                        dp_output_keep_prob=1.0):
    with tf.variable_scope('fw'):
        cell_fw = create_cudnn_LSTM_cell(
            num_units=num_units,
            input_size=input_size,
            num_layers=num_layers,
            dp_input_keep_prob=dp_input_keep_prob,
            dp_output_keep_prob=dp_output_keep_prob)
    with tf.variable_scope('bw'):
        cell_bw = create_cudnn_LSTM_cell(
            num_units=num_units,
            input_size=input_size,
            num_layers=num_layers,
            dp_input_keep_prob=dp_input_keep_prob,
            dp_output_keep_prob=dp_output_keep_prob)

    if attention_mechanism:
        cell_fw = attention_wrapper.AttentionWrapper(
            cell=cell_fw,
            attention_mechanism=attention_mechanism,
            output_attention=False)
        cell_bw = attention_wrapper.AttentionWrapper(
            cell=cell_bw,
            attention_mechanism=attention_mechanism,
            output_attention=False)

    return bidirectional_dynamic_rnn(cell_fw=cell_fw,
                                     cell_bw=cell_bw,
                                     inputs=inputs,
                                     sequence_length=lengths,
                                     dtype=getdtype())
def build_decoder(self):
    """Build the decoder with attention."""
    print('build decoder with attention...')
    with tf.variable_scope('decoder'):
        self.decoder_embeddings = tf.Variable(
            tf.random_uniform([self.vocab_size, self.embedding_size]))

        # 2.1 add attention
        def build_decoder_cell():
            decoder_cell = tf.contrib.rnn.LSTMCell(
                self.hidden_size,
                initializer=tf.random_uniform_initializer(-0.1, 0.1, seed=2))
            return decoder_cell

        attention_states = self.encoder_outputs
        attention_mechanism = tf.contrib.seq2seq.LuongAttention(
            num_units=self.hidden_size,
            memory=attention_states,
            memory_sequence_length=self.source_sequence_length)

        decoder_cells_list = [
            build_decoder_cell() for _ in range(self.num_layers)
        ]
        decoder_cells_list[-1] = attention_wrapper.AttentionWrapper(
            cell=decoder_cells_list[-1],
            attention_mechanism=attention_mechanism,
            attention_layer_size=self.hidden_size)
        self.decoder_cells = tf.contrib.rnn.MultiRNNCell(decoder_cells_list)

        initial_state = [state for state in self.encoder_states]
        initial_state[-1] = decoder_cells_list[-1].zero_state(
            batch_size=self.batch_size, dtype=tf.float32)
        self.decoder_initial_state = tuple(initial_state)

        # Fully-connected output layer
        self.output_layer = Dense(
            self.vocab_size,
            kernel_initializer=tf.truncated_normal_initializer(mean=0.0,
                                                               stddev=0.1))

        if self.mode == 'train':
            self.interfer()
        elif self.mode == 'decode':
            self.decode()
def _attentive_bidirectional_rnn(self,
                                 inputs,
                                 lengths,
                                 num_units,
                                 attention_mechanism=False,
                                 cell_type='gru',
                                 num_layers=1,
                                 dp_input_keep_prob=1.0,
                                 dp_output_keep_prob=1.0):
    cell_fw = create_rnn_cell(cell_type=cell_type,
                              num_units=num_units,
                              num_layers=num_layers,
                              dp_input_keep_prob=dp_input_keep_prob,
                              dp_output_keep_prob=dp_output_keep_prob)
    cell_bw = create_rnn_cell(cell_type=cell_type,
                              num_units=num_units,
                              num_layers=num_layers,
                              dp_input_keep_prob=dp_input_keep_prob,
                              dp_output_keep_prob=dp_output_keep_prob)

    if attention_mechanism:
        cell_fw = attention_wrapper.AttentionWrapper(
            cell=cell_fw,
            attention_mechanism=attention_mechanism,
            output_attention=False)
        cell_bw = attention_wrapper.AttentionWrapper(
            cell=cell_bw,
            attention_mechanism=attention_mechanism,
            output_attention=False)

    return bidirectional_dynamic_rnn(cell_fw=cell_fw,
                                     cell_bw=cell_bw,
                                     inputs=inputs,
                                     sequence_length=lengths,
                                     dtype=getdtype())
def build_decoder_cell(self):
    encoder_outputs = self.encoder_outputs
    encoder_last_state = self.encoder_last_state
    encoder_inputs_length = self.encoder_inputs_length

    # building attention mechanism: default Bahdanau
    # 'Bahdanau': https://arxiv.org/abs/1409.0473
    self.attention_mechanism = attention_wrapper.BahdanauAttention(
        num_units=self.hidden_size,
        memory=encoder_outputs,
        memory_sequence_length=encoder_inputs_length)
    # 'Luong': https://arxiv.org/abs/1508.04025
    if self.attention_type.lower() == 'luong':
        self.attention_mechanism = attention_wrapper.LuongAttention(
            num_units=self.hidden_size,
            memory=self.encoder_outputs,
            memory_sequence_length=self.encoder_inputs_length)

    # building decoder_cell
    self.decoder_cell_list = [
        self.build_single_cell() for _ in range(self.layer_num)
    ]

    def att_decoder_input_fn(inputs, attention):
        if not self.use_att_decoding:
            return inputs
        _input_layer = Dense(self.hidden_size, dtype=self.dtype,
                             name='att_input_feeding')
        return _input_layer(array_ops.concat([inputs, attention], axis=-1))

    # AttentionWrapper wraps RNNCell with the attention_mechanism
    # implement attention mechanism only on the top of decoder layer
    self.decoder_cell_list[-1] = attention_wrapper.AttentionWrapper(
        cell=self.decoder_cell_list[-1],
        attention_mechanism=self.attention_mechanism,
        attention_layer_size=self.hidden_size,
        cell_input_fn=att_decoder_input_fn,
        initial_cell_state=encoder_last_state[-1],  # last hidden state of last encode layer
        alignment_history=False,
        name='Attention_Wrapper')

    initial_state = [state for state in encoder_last_state]
    initial_state[-1] = self.decoder_cell_list[-1].zero_state(
        batch_size=self.batch_size, dtype=self.dtype)
    decoder_initial_state = tuple(initial_state)

    return MultiRNNCell(self.decoder_cell_list), decoder_initial_state
def build_decoder_cell(self, num_units, num_layers, keep_prob):
    encoder_outputs = tf.concat(self.encoder_outputs, axis=-1)
    encoder_final_state = []
    encoder_fw_fs, encoder_bw_fs = self.encoder_fs
    for i in range(num_layers):
        final_state_c = tf.concat((encoder_fw_fs[i].c, encoder_bw_fs[i].c),
                                  axis=1)
        final_state_h = tf.concat((encoder_fw_fs[i].h, encoder_bw_fs[i].h),
                                  axis=1)
        encoder_final_state.append(
            LSTMStateTuple(c=final_state_c, h=final_state_h))
    encoder_fs = tuple(encoder_final_state)

    # build decoder cell
    decoder_cells = [
        self.make_rnn_cell(num_units, keep_prob) for _ in range(num_layers)
    ]
    attention_cell = decoder_cells.pop()

    # use Bahdanau attention on the top decoder cell
    self.attention_machenism = attention_wrapper.BahdanauAttention(
        num_units=num_units,
        memory=encoder_outputs,
        normalize=False,
        memory_sequence_length=self.encoder_length)
    attention_cell = attention_wrapper.AttentionWrapper(
        attention_cell,
        self.attention_machenism,
        attention_layer_size=None,
        initial_cell_state=None,
        output_attention=False,
        alignment_history=False,
    )
    decoder_cells.append(attention_cell)
    decoder_cells = tf.nn.rnn_cell.MultiRNNCell(decoder_cells)

    batch = self.batch
    # Clone the attention cell's zero state with the encoder final state;
    # the states of the plain (non-attention) layers are passed through as-is.
    decoder_init_state = tuple(
        zs.clone(cell_state=es)
        if isinstance(zs, tf.contrib.seq2seq.AttentionWrapperState) else es
        for zs, es in zip(
            decoder_cells.zero_state(batch, dtype=tf.float32), encoder_fs))

    return decoder_cells, decoder_init_state
def build_decoder_cell(self):
    encoder_outputs = self.encoder_outputs
    encoder_last_state = self.encoder_last_state
    encoder_inputs_length = self.encoder_inputs_length

    # Building Attention Mechanism: Default Bahdanau
    self.attention_mechanism = attention_wrapper.BahdanauAttention(
        num_units=self.decoder_hidden_units,
        memory=encoder_outputs,
        memory_sequence_length=encoder_inputs_length)

    self.decoder_cell_list = [
        self.build_single_cell() for i in range(self.depth)
    ]
    decoder_initial_state = encoder_last_state

    def attn_decoder_input_fn(inputs, attention):
        if not self.attn_input_feeding:
            return inputs
        _input_layer = Dense(self.decoder_hidden_units, dtype=tf.float32,
                             name='attn_input_feeding')
        return _input_layer(array_ops.concat([inputs, attention], -1))

    # AttentionWrapper wraps RNNCell with the attention_mechanism
    self.decoder_cell_list[-1] = attention_wrapper.AttentionWrapper(
        cell=self.decoder_cell_list[-1],
        attention_mechanism=self.attention_mechanism,
        attention_layer_size=self.decoder_hidden_units,
        cell_input_fn=attn_decoder_input_fn,
        initial_cell_state=encoder_last_state[-1],
        alignment_history=False,
        name='Attention_Wrapper')

    batch_size = self.batch_size
    initial_state = [state for state in encoder_last_state]
    initial_state[-1] = self.decoder_cell_list[-1].zero_state(
        batch_size=batch_size, dtype=tf.float32)
    decoder_initial_state = tuple(initial_state)

    return tf.contrib.rnn.MultiRNNCell(
        self.decoder_cell_list), decoder_initial_state
def build_decoder_cell(self):
    self.decoder_cell_list = \
        [self.build_single_cell() for i in range(self.para.num_layers)]

    if self.para.mode == 'train':
        encoder_outputs = self.encoder_outputs
        encoder_inputs_len = self.encoder_inputs_len
        encoder_states = self.encoder_states
        batch_size = self.para.batch_size
    else:
        encoder_outputs = seq2seq.tile_batch(
            self.encoder_outputs, multiplier=self.para.beam_width)
        encoder_inputs_len = seq2seq.tile_batch(
            self.encoder_inputs_len, multiplier=self.para.beam_width)
        encoder_states = seq2seq.tile_batch(
            self.encoder_states, multiplier=self.para.beam_width)
        batch_size = self.para.batch_size * self.para.beam_width

    if self.para.attention_mode == 'luong':
        # scaled luong: recommended by authors of NMT
        self.attention_mechanism = attention_wrapper.LuongAttention(
            num_units=self.para.num_units,
            memory=encoder_outputs,
            memory_sequence_length=encoder_inputs_len,
            scale=True)
        output_attention = True
    else:
        self.attention_mechanism = attention_wrapper.BahdanauAttention(
            num_units=self.para.num_units,
            memory=encoder_outputs,
            memory_sequence_length=encoder_inputs_len)
        output_attention = False

    cell = tf.contrib.rnn.MultiRNNCell(self.decoder_cell_list)
    cell = attention_wrapper.AttentionWrapper(
        cell=cell,
        attention_mechanism=self.attention_mechanism,
        attention_layer_size=self.para.num_units,
        name='attention')
    decoder_initial_state = cell.zero_state(
        batch_size, self.dtype).clone(cell_state=encoder_states)

    return cell, decoder_initial_state
def build_attention_decoder_cell(self):
    encoder_outputs = self.encoder_outputs
    encoder_last_state = self.encoder_last_state
    encoder_inputs_length = self.encoder_inputs_length

    self.attention_mechanism = attention_wrapper.BahdanauAttention(
        num_units=self.hidden_units,
        memory=encoder_outputs,
        memory_sequence_length=encoder_inputs_length,
    )
    if self.attention_type.lower() == 'luong':
        self.attention_mechanism = attention_wrapper.LuongAttention(
            num_units=self.hidden_units,
            memory=encoder_outputs,
            memory_sequence_length=encoder_inputs_length,
        )

    # Building decoder_cell
    self.decoder_cell_list = [
        self.build_single_cell() for i in range(self.num_layers)]

    def attn_decoder_input_fn(inputs, attention):
        if not self.attn_input_feeding:
            return inputs
        # Essential when use_residual=True
        _input_layer = tf.layers.dense(
            tf.concat([inputs, attention], axis=-1),
            self.hidden_units,
            name='attn_input_feeding')
        return _input_layer

    self.decoder_cell_list[-1] = attention_wrapper.AttentionWrapper(
        cell=self.decoder_cell_list[-1],
        attention_mechanism=self.attention_mechanism,
        attention_layer_size=self.hidden_units,
        cell_input_fn=attn_decoder_input_fn,
        initial_cell_state=encoder_last_state[-1],
        alignment_history=False,
        name='Attention_Wrapper'
    )

    batch_size = self.config['batch_size']
    initial_state = [state for state in encoder_last_state]
    initial_state[-1] = self.decoder_cell_list[-1].zero_state(
        batch_size=batch_size, dtype=self.dtype)
    decoder_initial_state = tuple(initial_state)

    return MultiRNNCell(self.decoder_cell_list), decoder_initial_state
def testAttentionWrapperStateShapePropgation(self):
    batch_size = 5
    max_time = 5
    num_units = 5

    memory = random_ops.random_uniform(
        [batch_size, max_time, num_units], seed=1)
    mechanism = wrapper.LuongAttention(num_units, memory)
    cell = wrapper.AttentionWrapper(rnn_cell.LSTMCell(num_units), mechanism)

    # Create zero state with static batch size.
    static_state = cell.zero_state(batch_size, dtypes.float32)
    # Create zero state without static batch size.
    state = cell.zero_state(array_ops.shape(memory)[0], dtypes.float32)

    state = static_state.clone(
        cell_state=state.cell_state, attention=state.attention)

    self.assertEqual(state.cell_state.c.shape, static_state.cell_state.c.shape)
    self.assertEqual(state.cell_state.h.shape, static_state.cell_state.h.shape)
    self.assertEqual(state.attention.shape, static_state.attention.shape)
def _testBahdanauNormalizedDType(self, dtype):
    encoder_outputs = self.encoder_outputs.astype(dtype)
    decoder_inputs = self.decoder_inputs.astype(dtype)
    attention_mechanism = wrapper.BahdanauAttentionV2(
        units=self.units,
        memory=encoder_outputs,
        memory_sequence_length=self.encoder_sequence_length,
        normalize=True,
        dtype=dtype)
    cell = rnn_cell.LSTMCell(self.units)
    cell = wrapper.AttentionWrapper(cell, attention_mechanism)

    sampler = sampler_py.TrainingSampler()
    my_decoder = basic_decoder.BasicDecoderV2(cell=cell, sampler=sampler)

    final_outputs, final_state, _ = my_decoder(
        decoder_inputs,
        initial_state=cell.zero_state(dtype=dtype, batch_size=self.batch),
        sequence_length=self.decoder_sequence_length)
    self.assertIsInstance(final_outputs, basic_decoder.BasicDecoderOutput)
    self.assertEqual(final_outputs.rnn_output.dtype, dtype)
    self.assertIsInstance(final_state, wrapper.AttentionWrapperState)
    self.assertIsInstance(final_state.cell_state, rnn_cell.LSTMStateTuple)
def _testWithMaybeMultiAttention(self,
                                 is_multi,
                                 create_attention_mechanisms,
                                 expected_final_output,
                                 expected_final_state,
                                 attention_mechanism_depths,
                                 alignment_history=False,
                                 expected_final_alignment_history=None,
                                 attention_layer_sizes=None,
                                 attention_layers=None,
                                 create_query_layer=False,
                                 create_memory_layer=True,
                                 create_attention_kwargs=None):
    # Allow is_multi to be True with a single mechanism to enable test for
    # passing in a single mechanism in a list.
    assert len(create_attention_mechanisms) == 1 or is_multi
    encoder_sequence_length = [3, 2, 3, 1, 1]
    decoder_sequence_length = [2, 0, 1, 2, 3]
    batch_size = 5
    encoder_max_time = 8
    decoder_max_time = 4
    input_depth = 7
    encoder_output_depth = 10
    cell_depth = 9
    create_attention_kwargs = create_attention_kwargs or {}

    if attention_layer_sizes is not None:
        # Compute sum of attention_layer_sizes. Use encoder_output_depth if None.
        attention_depth = sum(attention_layer_size or encoder_output_depth
                              for attention_layer_size in attention_layer_sizes)
    elif attention_layers is not None:
        # Compute sum of attention_layers output depth.
        attention_depth = sum(
            attention_layer.compute_output_shape(
                [batch_size, cell_depth + encoder_output_depth]).dims[-1].value
            for attention_layer in attention_layers)
    else:
        attention_depth = encoder_output_depth * len(create_attention_mechanisms)

    decoder_inputs = np.random.randn(batch_size, decoder_max_time,
                                     input_depth).astype(np.float32)
    encoder_outputs = np.random.randn(batch_size, encoder_max_time,
                                      encoder_output_depth).astype(np.float32)

    attention_mechanisms = []
    for creator, depth in zip(create_attention_mechanisms,
                              attention_mechanism_depths):
        # Create a memory layer with deterministic initializer to avoid
        # randomness in the test between graph and eager.
        if create_query_layer:
            create_attention_kwargs["query_layer"] = keras.layers.Dense(
                depth, kernel_initializer="ones", use_bias=False)
        if create_memory_layer:
            create_attention_kwargs["memory_layer"] = keras.layers.Dense(
                depth, kernel_initializer="ones", use_bias=False)

        attention_mechanisms.append(
            creator(units=depth,
                    memory=encoder_outputs,
                    memory_sequence_length=encoder_sequence_length,
                    **create_attention_kwargs))

    with self.cached_session(use_gpu=True):
        attention_layer_size = attention_layer_sizes
        attention_layer = attention_layers
        if not is_multi:
            if attention_layer_size is not None:
                attention_layer_size = attention_layer_size[0]
            if attention_layer is not None:
                attention_layer = attention_layer[0]

        cell = keras.layers.LSTMCell(cell_depth,
                                     recurrent_activation="sigmoid",
                                     kernel_initializer="ones",
                                     recurrent_initializer="ones")
        cell = wrapper.AttentionWrapper(
            cell,
            attention_mechanisms if is_multi else attention_mechanisms[0],
            attention_layer_size=attention_layer_size,
            alignment_history=alignment_history,
            attention_layer=attention_layer)
        if cell._attention_layers is not None:
            for layer in cell._attention_layers:
                if getattr(layer, "kernel_initializer") is None:
                    layer.kernel_initializer = initializers.glorot_uniform(
                        seed=1337)

        sampler = sampler_py.TrainingSampler()
        my_decoder = basic_decoder.BasicDecoderV2(cell=cell, sampler=sampler)
        initial_state = cell.get_initial_state(dtype=dtypes.float32,
                                               batch_size=batch_size)
        final_outputs, final_state, _ = my_decoder(
            decoder_inputs,
            initial_state=initial_state,
            sequence_length=decoder_sequence_length)

        self.assertIsInstance(final_outputs, basic_decoder.BasicDecoderOutput)
        self.assertIsInstance(final_state, wrapper.AttentionWrapperState)

        expected_time = (expected_final_state.time
                         if context.executing_eagerly() else None)
        self.assertEqual(
            (batch_size, expected_time, attention_depth),
            tuple(final_outputs.rnn_output.get_shape().as_list()))
        self.assertEqual(
            (batch_size, expected_time),
            tuple(final_outputs.sample_id.get_shape().as_list()))
        self.assertEqual(
            (batch_size, attention_depth),
            tuple(final_state.attention.get_shape().as_list()))
        self.assertEqual(
            (batch_size, cell_depth),
            tuple(final_state.cell_state[0].get_shape().as_list()))
        self.assertEqual(
            (batch_size, cell_depth),
            tuple(final_state.cell_state[1].get_shape().as_list()))

        if alignment_history:
            if is_multi:
                state_alignment_history = []
                for history_array in final_state.alignment_history:
                    history = history_array.stack()
                    self.assertEqual(
                        (expected_time, batch_size, encoder_max_time),
                        tuple(history.get_shape().as_list()))
                    state_alignment_history.append(history)
                state_alignment_history = tuple(state_alignment_history)
            else:
                state_alignment_history = final_state.alignment_history.stack()
                self.assertEqual(
                    (expected_time, batch_size, encoder_max_time),
                    tuple(state_alignment_history.get_shape().as_list()))
            nest.assert_same_structure(
                cell.state_size, cell.zero_state(batch_size, dtypes.float32))
            # Remove the history from final_state for purposes of the
            # remainder of the tests.
            final_state = final_state._replace(alignment_history=())  # pylint: disable=protected-access
        else:
            state_alignment_history = ()

        self.evaluate(variables.global_variables_initializer())
        eval_result = self.evaluate({
            "final_outputs": final_outputs,
            "final_state": final_state,
            "state_alignment_history": state_alignment_history,
        })

        final_output_info = nest.map_structure(get_result_summary,
                                               eval_result["final_outputs"])
        final_state_info = nest.map_structure(get_result_summary,
                                              eval_result["final_state"])
        print("final_output_info: ", final_output_info)
        print("final_state_info: ", final_state_info)

        nest.map_structure(self.assertAllCloseOrEqual, expected_final_output,
                           final_output_info)
        nest.map_structure(self.assertAllCloseOrEqual, expected_final_state,
                           final_state_info)
        if alignment_history:  # by default, the wrapper emits attention as output
            final_alignment_history_info = nest.map_structure(
                get_result_summary, eval_result["state_alignment_history"])
            print("final_alignment_history_info: ",
                  final_alignment_history_info)
            nest.map_structure(
                self.assertAllCloseOrEqual,
                # outputs are batch major but the stacked TensorArray is time major
                expected_final_alignment_history,
                final_alignment_history_info)
def _testWithAttention(self,
                       create_attention_mechanism,
                       expected_final_output,
                       expected_final_state,
                       attention_mechanism_depth=3,
                       alignment_history=False,
                       expected_final_alignment_history=None,
                       attention_layer_size=6,
                       name=''):
    encoder_sequence_length = [3, 2, 3, 1, 1]
    decoder_sequence_length = [2, 0, 1, 2, 3]
    batch_size = 5
    encoder_max_time = 8
    decoder_max_time = 4
    input_depth = 7
    encoder_output_depth = 10
    cell_depth = 9

    if attention_layer_size is not None:
        attention_depth = attention_layer_size
    else:
        attention_depth = encoder_output_depth

    decoder_inputs = array_ops.placeholder_with_default(
        np.random.randn(batch_size, decoder_max_time,
                        input_depth).astype(np.float32),
        shape=(None, None, input_depth))
    encoder_outputs = array_ops.placeholder_with_default(
        np.random.randn(batch_size, encoder_max_time,
                        encoder_output_depth).astype(np.float32),
        shape=(None, None, encoder_output_depth))

    attention_mechanism = create_attention_mechanism(
        num_units=attention_mechanism_depth,
        memory=encoder_outputs,
        memory_sequence_length=encoder_sequence_length)

    with self.test_session(use_gpu=True) as sess:
        with vs.variable_scope(
                'root',
                initializer=init_ops.random_normal_initializer(stddev=0.01,
                                                               seed=3)):
            cell = rnn_cell.LSTMCell(cell_depth)
            cell = wrapper.AttentionWrapper(
                cell,
                attention_mechanism,
                attention_layer_size=attention_layer_size,
                alignment_history=alignment_history)
            helper = helper_py.TrainingHelper(decoder_inputs,
                                              decoder_sequence_length)
            my_decoder = basic_decoder.BasicDecoder(
                cell=cell,
                helper=helper,
                initial_state=cell.zero_state(dtype=dtypes.float32,
                                              batch_size=batch_size))
            final_outputs, final_state, _ = decoder.dynamic_decode(my_decoder)

        self.assertTrue(
            isinstance(final_outputs, basic_decoder.BasicDecoderOutput))
        self.assertTrue(
            isinstance(final_state, wrapper.AttentionWrapperState))
        self.assertTrue(
            isinstance(final_state.cell_state, rnn_cell.LSTMStateTuple))

        self.assertEqual(
            (batch_size, None, attention_depth),
            tuple(final_outputs.rnn_output.get_shape().as_list()))
        self.assertEqual(
            (batch_size, None),
            tuple(final_outputs.sample_id.get_shape().as_list()))
        self.assertEqual(
            (batch_size, attention_depth),
            tuple(final_state.attention.get_shape().as_list()))
        self.assertEqual(
            (batch_size, cell_depth),
            tuple(final_state.cell_state.c.get_shape().as_list()))
        self.assertEqual(
            (batch_size, cell_depth),
            tuple(final_state.cell_state.h.get_shape().as_list()))

        if alignment_history:
            state_alignment_history = final_state.alignment_history.stack()
            # Remove the history from final_state for purposes of the
            # remainder of the tests.
            final_state = final_state._replace(alignment_history=())  # pylint: disable=protected-access
            self.assertEqual(
                (None, batch_size, None),
                tuple(state_alignment_history.get_shape().as_list()))
        else:
            state_alignment_history = ()

        sess.run(variables.global_variables_initializer())
        sess_results = sess.run({
            'final_outputs': final_outputs,
            'final_state': final_state,
            'state_alignment_history': state_alignment_history,
        })

        final_output_info = nest.map_structure(get_result_summary,
                                               sess_results['final_outputs'])
        final_state_info = nest.map_structure(get_result_summary,
                                              sess_results['final_state'])
        print(name)
        print('Copy/paste:\nexpected_final_output = %s' % str(final_output_info))
        print('expected_final_state = %s' % str(final_state_info))
        nest.map_structure(self.assertAllCloseOrEqual, expected_final_output,
                           final_output_info)
        nest.map_structure(self.assertAllCloseOrEqual, expected_final_state,
                           final_state_info)
        if alignment_history:  # by default, the wrapper emits attention as output
            final_alignment_history_info = nest.map_structure(
                get_result_summary, sess_results['state_alignment_history'])
            print('expected_final_alignment_history = %s' %
                  str(final_alignment_history_info))
            nest.map_structure(
                self.assertAllCloseOrEqual,
                # outputs are batch major but the stacked TensorArray is time major
                expected_final_alignment_history,
                final_alignment_history_info)
def build_decoder_cell(self):
    encoder_outputs = self.encoder_outputs
    encoder_last_state = self.encoder_last_state
    encoder_inputs_length = self.encoder_inputs_length

    # To use BeamSearchDecoder, encoder_outputs, encoder_last_state, encoder_inputs_length
    # need to be tiled so that: [batch_size, .., ..] -> [batch_size x beam_width, .., ..]
    if self.use_beamsearch_decode:
        print("use beamsearch decoding..")
        encoder_outputs = seq2seq.tile_batch(
            self.encoder_outputs, multiplier=self.beam_width)
        encoder_last_state = nest.map_structure(
            lambda s: seq2seq.tile_batch(s, self.beam_width),
            self.encoder_last_state)
        encoder_inputs_length = seq2seq.tile_batch(
            self.encoder_inputs_length, multiplier=self.beam_width)

    # Building attention mechanism: Default Bahdanau
    # 'Bahdanau' style attention: https://arxiv.org/abs/1409.0473
    self.attention_mechanism = attention_wrapper.BahdanauAttention(
        num_units=self.hidden_units,
        memory=encoder_outputs,
        memory_sequence_length=encoder_inputs_length,
    )
    # 'Luong' style attention: https://arxiv.org/abs/1508.04025
    if self.attention_type.lower() == 'luong':
        self.attention_mechanism = attention_wrapper.LuongAttention(
            num_units=self.hidden_units,
            memory=encoder_outputs,
            memory_sequence_length=encoder_inputs_length,
        )

    # Building decoder_cell
    self.decoder_cell_list = [
        self.build_single_cell() for i in range(self.depth)
    ]
    decoder_initial_state = encoder_last_state

    def attn_decoder_input_fn(inputs, attention):
        if not self.attn_input_feeding:
            return inputs
        # Essential when use_residual=True
        _input_layer = Dense(self.hidden_units, dtype=self.dtype,
                             name='attn_input_feeding')
        return _input_layer(array_ops.concat([inputs, attention], -1))

    # AttentionWrapper wraps RNNCell with the attention_mechanism
    # Note: We implement Attention mechanism only on the top decoder layer
    self.decoder_cell_list[-1] = attention_wrapper.AttentionWrapper(
        cell=self.decoder_cell_list[-1],
        attention_mechanism=self.attention_mechanism,
        attention_layer_size=self.hidden_units,
        cell_input_fn=attn_decoder_input_fn,
        initial_cell_state=encoder_last_state[-1],
        alignment_history=True,
        name='Attention_Wrapper')

    # To be compatible with AttentionWrapper, the encoder last state
    # of the top layer should be converted into the AttentionWrapperState form.
    # We can easily do this by calling AttentionWrapper.zero_state.
    # Also, if beamsearch decoding is used, the batch_size argument in .zero_state
    # should be beam_width times the original batch_size.
    batch_size = self.batch_size if not self.use_beamsearch_decode \
        else self.batch_size * self.beam_width
    initial_state = [state for state in encoder_last_state]
    initial_state[-1] = self.decoder_cell_list[-1].zero_state(
        batch_size=batch_size, dtype=self.dtype)
    decoder_initial_state = tuple(initial_state)

    return MultiRNNCell(self.decoder_cell_list), decoder_initial_state
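# A hedged sketch (not from the original snippet): since the variant above sets
# alignment_history=True on the top-layer AttentionWrapper, the attention
# weights can be recovered from the final decoder state after decoding.
# `final_state` is assumed to be the MultiRNNCell state returned by
# tf.contrib.seq2seq.dynamic_decode when using the cell built above; the
# AttentionWrapperState sits in its last slot.
import tensorflow as tf


def get_alignment_history(final_state):
    attention_state = final_state[-1]  # AttentionWrapperState of the top layer
    # alignment_history is a TensorArray; stack() yields [decode_steps, batch, src_len]
    alignments = attention_state.alignment_history.stack()
    # Transpose to batch-major [batch, decode_steps, src_len] for inspection/plotting.
    return tf.transpose(alignments, [1, 0, 2])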
def _build_output_layer_context(self, ques_outputs, ques_sequence_length,
                                ctx_outputs, ctx_sequence_length):
    with tf.variable_scope('Output'):
        attention_depth_ques = self._model_params[
            'output_attention_layer_size_ques']
        if self._attention_type_output_layer == 'Luong':
            attention_depth_ques = self._model_params['ques_param_size']
        attention_mechanism_ques = self._build_attention(
            num_units=attention_depth_ques,
            memory=ques_outputs,
            memory_sequence_length=ques_sequence_length,
            attention_type=self._attention_type_output_layer)

        self.V_ques = tf.get_variable(
            name='V_ques',
            shape=[1, self._model_params['ques_param_size']],
            dtype=getdtype())
        alignments_ques = attention_mechanism_ques(
            tf.tile(self.V_ques, [self._model_params['batch_size'], 1]),
            previous_alignments=None)
        expanded_alignments_ques = tf.expand_dims(alignments_ques, 1)
        context = tf.matmul(expanded_alignments_ques, ques_outputs)
        context = tf.squeeze(context, [1])

        attention_depth_ans = self._model_params[
            'output_attention_layer_size_ans']
        if self._attention_type_output_layer == 'Luong':
            attention_depth_ans = self.embedded_dim
        attention_mechanism_ans = self._build_attention(
            num_units=attention_depth_ans,
            memory=ctx_outputs,
            memory_sequence_length=ctx_sequence_length,
            attention_type=self._attention_type_output_layer)

        output_rnn_cell = create_rnn_cell(
            cell_type=self._model_params['output_cell_type'],
            num_units=self.embedded_dim,
            num_layers=self._model_params['output_layers'],
            dp_input_keep_prob=self._model_params['output_dp_input_keep_prob'],
            dp_output_keep_prob=self._model_params['output_dp_output_keep_prob'])

        # context as initial state
        if self._model_params['output_cell_type'] == 'gru':
            if self._model_params['output_layers'] == 1:
                initial_state = context
            else:
                initial_state = tuple(
                    context
                    for _ in range(self._model_params['output_layers']))
        elif self._model_params['output_cell_type'] == 'lstm':
            if self._model_params['output_layers'] == 1:
                initial_state = rnn_cell_impl.LSTMStateTuple(
                    tf.zeros_like(context, dtype=getdtype()), context)
            else:
                initial_state = tuple(
                    rnn_cell_impl.LSTMStateTuple(
                        tf.zeros_like(context, dtype=getdtype()), context)
                    for _ in range(self._model_params['output_layers']))

        attentive_output_cell = attention_wrapper.AttentionWrapper(
            cell=output_rnn_cell,
            attention_mechanism=attention_mechanism_ans,
            alignment_history=True,
            cell_input_fn=lambda _, attention: attention,
            initial_cell_state=initial_state,
            output_attention=False)

        final_outputs, final_state = tf.nn.static_rnn(
            cell=attentive_output_cell,
            inputs=[
                tf.zeros([self._model_params['batch_size'], 1]),
                tf.zeros([self._model_params['batch_size'], 1])
            ],
            dtype=getdtype())
        alignment_history = final_state.alignment_history
        return alignment_history
def _attention_decoder_wrapper(batch_size, num_units, memory, mutli_layer,
                               dtype=dtypes.float32, attention_layer_size=None,
                               cell_input_fn=None, attention_type='B',
                               probability_fn=None, alignment_history=False,
                               output_attention=True, initial_cell_state=None,
                               normalization=False, sigmoid_noise=0.,
                               sigmoid_noise_seed=None, score_bias_init=0.):
    """A wrapper for an RNN decoder with an attention mechanism.

    A detailed explanation of the parameters can be found at:
    blog.csdn.net/qsczse943062710/article/details/79539005

    :param mutli_layer: an object returned by the function _mutli_layer_rnn()
    :param attention_type: string
        'B' is for BahdanauAttention as described in:
            Dzmitry Bahdanau, Kyunghyun Cho, Yoshua Bengio.
            "Neural Machine Translation by Jointly Learning to Align and Translate."
            ICLR 2015. https://arxiv.org/abs/1409.0473
        'L' is for LuongAttention as described in:
            Minh-Thang Luong, Hieu Pham, Christopher D. Manning.
            "Effective Approaches to Attention-based Neural Machine Translation."
            EMNLP 2015. https://arxiv.org/abs/1508.04025
        MonotonicAttention is described in:
            Colin Raffel, Minh-Thang Luong, Peter J. Liu, Ron J. Weiss, Douglas Eck.
            "Online and Linear-Time Attention by Enforcing Monotonic Alignments."
            ICML 2017. https://arxiv.org/abs/1704.00784
        'BM': Monotonic attention mechanism with Bahdanau-style energy function
        'LM': Monotonic attention mechanism with Luong-style energy function
        or maybe something user defined in the future

    **warning**: if normalization is set True, then normalization will be applied
    to all types of attention, as described in:
        Tim Salimans, Diederik P. Kingma.
        "Weight Normalization: A Simple Reparameterization to Accelerate Training
        of Deep Neural Networks." https://arxiv.org/abs/1602.07868

    An example usage:
        att_wrapper, states = _attention_decoder_wrapper(*args)
        while decoding:
            output, states = att_wrapper(input, states)
            ... some processing on output ...
            input = processed_output
    """
    if attention_type == 'B':
        attention_mechanism = att_w.BahdanauAttention(
            num_units=num_units, memory=memory,
            probability_fn=probability_fn, normalize=normalization)
    elif attention_type == 'BM':
        attention_mechanism = att_w.BahdanauMonotonicAttention(
            num_units=num_units, memory=memory, normalize=normalization,
            sigmoid_noise=sigmoid_noise, sigmoid_noise_seed=sigmoid_noise_seed,
            score_bias_init=score_bias_init)
    elif attention_type == 'L':
        attention_mechanism = att_w.LuongAttention(
            num_units=num_units, memory=memory,
            probability_fn=probability_fn, scale=normalization)
    elif attention_type == 'LM':
        attention_mechanism = att_w.LuongMonotonicAttention(
            num_units=num_units, memory=memory, scale=normalization,
            sigmoid_noise=sigmoid_noise, sigmoid_noise_seed=sigmoid_noise_seed,
            score_bias_init=score_bias_init)
    else:
        # raising a plain string is a TypeError in Python 3; raise a ValueError instead
        raise ValueError('Invalid attention type')

    att_wrapper = att_w.AttentionWrapper(
        cell=mutli_layer,
        attention_mechanism=attention_mechanism,
        attention_layer_size=attention_layer_size,
        cell_input_fn=cell_input_fn,
        alignment_history=alignment_history,
        output_attention=output_attention,
        initial_cell_state=initial_cell_state)
    init_states = att_wrapper.zero_state(batch_size=batch_size, dtype=dtype)
    return att_wrapper, init_states
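# A minimal, hedged usage sketch for the wrapper above, concretizing the
# pseudo-loop from its docstring. Assumptions: a stand-in MultiRNNCell is used
# in place of the hypothetical _mutli_layer_rnn() helper, `memory` is a
# [batch, src_len, num_units] tensor (its depth matches num_units so the
# attention output fed back in has a constant size), and decoding is unrolled
# for a fixed number of steps in graph mode.
import tensorflow as tf


def decode_sketch(memory, batch_size=32, num_units=128, num_steps=10):
    mutli_layer = tf.nn.rnn_cell.MultiRNNCell(
        [tf.nn.rnn_cell.LSTMCell(num_units) for _ in range(2)])  # stand-in cell
    att_wrapper, states = _attention_decoder_wrapper(
        batch_size=batch_size, num_units=num_units,
        memory=memory, mutli_layer=mutli_layer, attention_type='B')
    inputs = tf.zeros([batch_size, num_units])  # e.g. an embedded <GO> token
    outputs = []
    for _ in range(num_steps):
        # With output_attention=True the step output is the attention vector.
        output, states = att_wrapper(inputs, states)
        outputs.append(output)
        inputs = output  # feed the previous output back in (greedy sketch)
    return tf.stack(outputs, axis=1)  # [batch, num_steps, attention depth]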
def _testWithAttention(self,
                       create_attention_mechanism,
                       expected_final_output,
                       expected_final_state,
                       attention_mechanism_depth=3,
                       alignment_history=False,
                       expected_final_alignment_history=None,
                       name=""):
    encoder_sequence_length = [3, 2, 3, 1, 0]
    decoder_sequence_length = [2, 0, 1, 2, 3]
    batch_size = 5
    encoder_max_time = 8
    decoder_max_time = 4
    input_depth = 7
    encoder_output_depth = 10
    cell_depth = 9
    attention_depth = 6

    decoder_inputs = np.random.randn(batch_size, decoder_max_time,
                                     input_depth).astype(np.float32)
    encoder_outputs = np.random.randn(batch_size, encoder_max_time,
                                      encoder_output_depth).astype(np.float32)

    attention_mechanism = create_attention_mechanism(
        num_units=attention_mechanism_depth,
        memory=encoder_outputs,
        memory_sequence_length=encoder_sequence_length)

    with self.test_session(use_gpu=True) as sess:
        with vs.variable_scope(
                "root",
                initializer=init_ops.random_normal_initializer(stddev=0.01,
                                                               seed=3)):
            cell = core_rnn_cell.LSTMCell(cell_depth)
            cell = wrapper.AttentionWrapper(
                cell,
                attention_mechanism,
                attention_size=attention_depth,
                alignment_history=alignment_history)
            helper = helper_py.TrainingHelper(decoder_inputs,
                                              decoder_sequence_length)
            my_decoder = basic_decoder.BasicDecoder(
                cell=cell,
                helper=helper,
                initial_state=cell.zero_state(dtype=dtypes.float32,
                                              batch_size=batch_size))
            final_outputs, final_state = decoder.dynamic_decode(my_decoder)

        self.assertTrue(
            isinstance(final_outputs, basic_decoder.BasicDecoderOutput))
        self.assertTrue(
            isinstance(final_state, wrapper.AttentionWrapperState))
        self.assertTrue(
            isinstance(final_state.cell_state, core_rnn_cell.LSTMStateTuple))

        self.assertEqual(
            (batch_size, None, attention_depth),
            tuple(final_outputs.rnn_output.get_shape().as_list()))
        self.assertEqual(
            (batch_size, None),
            tuple(final_outputs.sample_id.get_shape().as_list()))
        self.assertEqual(
            (batch_size, attention_depth),
            tuple(final_state.attention.get_shape().as_list()))
        self.assertEqual(
            (batch_size, cell_depth),
            tuple(final_state.cell_state.c.get_shape().as_list()))
        self.assertEqual(
            (batch_size, cell_depth),
            tuple(final_state.cell_state.h.get_shape().as_list()))

        if alignment_history:
            state_alignment_history = final_state.alignment_history.stack()
            # Remove the history from final_state for purposes of the
            # remainder of the tests.
            final_state = final_state._replace(alignment_history=())  # pylint: disable=protected-access
            self.assertEqual(
                (None, batch_size, encoder_max_time),
                tuple(state_alignment_history.get_shape().as_list()))
        else:
            state_alignment_history = ()

        sess.run(variables.global_variables_initializer())
        sess_results = sess.run({
            "final_outputs": final_outputs,
            "final_state": final_state,
            "state_alignment_history": state_alignment_history,
        })

        print("Copy/paste (%s)\nexpected_final_output = " % name,
              sess_results["final_outputs"])
        sys.stdout.flush()
        print("Copy/paste (%s)\nexpected_final_state = " % name,
              sess_results["final_state"])
        sys.stdout.flush()
        print("Copy/paste (%s)\nexpected_final_alignment_history = " % name,
              sess_results["state_alignment_history"])
        sys.stdout.flush()

        nest.map_structure(self.assertAllClose, expected_final_output,
                           sess_results["final_outputs"])
        nest.map_structure(self.assertAllClose, expected_final_state,
                           sess_results["final_state"])
        if alignment_history:  # by default, the wrapper emits attention as output
            self.assertAllClose(
                # outputs are batch major but the stacked TensorArray is time major
                sess_results["state_alignment_history"],
                expected_final_alignment_history)
def _build_decoder(self,
                   encoder_outputs,
                   enc_src_lengths,
                   tgt_inputs=None,
                   tgt_lengths=None,
                   GO_SYMBOL=1,
                   END_SYMBOL=2,
                   out_layer_activation=None):
    """
    Builds decoder part of the graph, for training and inference
    TODO: add param tensor shapes
    :param encoder_outputs:
    :param enc_src_lengths:
    :param tgt_inputs:
    :param tgt_lengths:
    :param GO_SYMBOL:
    :param END_SYMBOL:
    :param out_layer_activation:
    :return:
    """
    with tf.variable_scope("Decoder"):
        tgt_vocab_size = self.model_params['tgt_vocab_size']
        tgt_emb_size = self.model_params['tgt_emb_size']
        self._tgt_w = tf.get_variable(name='W_tgt_embedding',
                                      shape=[tgt_vocab_size, tgt_emb_size],
                                      dtype=getdtype())
        batch_size = self.model_params['batch_size']

        decoder_cell = create_rnn_cell(
            cell_type=self.model_params['decoder_cell_type'],
            cell_params={"num_units": self.model_params['decoder_cell_units']},
            num_layers=self.model_params['decoder_layers'],
            dp_input_keep_prob=self.model_params['decoder_dp_input_keep_prob']
            if self._mode == "train" else 1.0,
            dp_output_keep_prob=self.model_params['decoder_dp_output_keep_prob']
            if self._mode == "train" else 1.0,
            residual_connections=self.model_params['decoder_use_skip_connections'])

        output_layer = layers_core.Dense(tgt_vocab_size, use_bias=False,
                                         activation=out_layer_activation)

        def attn_decoder_custom_fn(inputs, attention):
            # to make shapes equal for skip connections
            if self.model_params['decoder_use_skip_connections']:
                input_layer = layers_core.Dense(
                    self.model_params['decoder_cell_units'], dtype=getdtype())
                return input_layer(tf.concat([inputs, attention], -1))
            else:
                return tf.concat([inputs, attention], -1)

        if self.mode == "infer":
            if self._decoder_type == "beam_search":
                self._length_penalty_weight = 1.0 \
                    if "length_penalty" not in self.model_params \
                    else self.model_params["length_penalty"]
                # beam_width of 1 should be same as argmax decoder
                self._beam_width = 1 \
                    if "beam_width" not in self.model_params \
                    else self.model_params["beam_width"]
                tiled_enc_outputs = tf.contrib.seq2seq.tile_batch(
                    encoder_outputs, multiplier=self._beam_width)
                tiled_enc_src_lengths = tf.contrib.seq2seq.tile_batch(
                    enc_src_lengths, multiplier=self._beam_width)
                attention_mechanism = self._build_attention(
                    tiled_enc_outputs, tiled_enc_src_lengths)

                attentive_decoder_cell = attention_wrapper.AttentionWrapper(
                    cell=decoder_cell,
                    attention_mechanism=attention_mechanism,
                    cell_input_fn=attn_decoder_custom_fn)

                batch_size_tensor = tf.constant(batch_size)
                decoder = tf.contrib.seq2seq.BeamSearchDecoder(
                    cell=attentive_decoder_cell,
                    embedding=self._tgt_w,
                    start_tokens=tf.tile([GO_SYMBOL], [batch_size]),
                    end_token=END_SYMBOL,
                    initial_state=attentive_decoder_cell.zero_state(
                        dtype=getdtype(),
                        batch_size=batch_size_tensor * self._beam_width),
                    beam_width=self._beam_width,
                    output_layer=output_layer,
                    length_penalty_weight=self._length_penalty_weight)
            else:
                attention_mechanism = self._build_attention(
                    encoder_outputs, enc_src_lengths)
                attentive_decoder_cell = attention_wrapper.AttentionWrapper(
                    cell=decoder_cell,
                    attention_mechanism=attention_mechanism,
                    cell_input_fn=attn_decoder_custom_fn)
                helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
                    embedding=self._tgt_w,
                    start_tokens=tf.fill([batch_size], GO_SYMBOL),
                    end_token=END_SYMBOL)
                decoder = tf.contrib.seq2seq.BasicDecoder(
                    cell=attentive_decoder_cell,
                    helper=helper,
                    initial_state=attentive_decoder_cell.zero_state(
                        batch_size=batch_size, dtype=getdtype()),
                    output_layer=output_layer)
        elif self.mode == "train" or self.mode == "eval":
            attention_mechanism = self._build_attention(
                encoder_outputs, enc_src_lengths)
            attentive_decoder_cell = attention_wrapper.AttentionWrapper(
                cell=decoder_cell,
                attention_mechanism=attention_mechanism,
                cell_input_fn=attn_decoder_custom_fn)
            input_vectors = tf.nn.embedding_lookup(self._tgt_w, tgt_inputs)
            helper = tf.contrib.seq2seq.TrainingHelper(
                inputs=input_vectors,
                sequence_length=tgt_lengths)
            decoder = tf.contrib.seq2seq.BasicDecoder(
                cell=attentive_decoder_cell,
                helper=helper,
                output_layer=output_layer,
                initial_state=attentive_decoder_cell.zero_state(
                    batch_size, dtype=getdtype()))
        else:
            raise NotImplementedError("Unknown mode")

        final_outputs, final_state, final_sequence_lengths = \
            tf.contrib.seq2seq.dynamic_decode(
                decoder=decoder,
                impute_finished=False if self._decoder_type == "beam_search"
                else True,
                maximum_iterations=tf.reduce_max(tgt_lengths)
                if self._mode == 'train' else tf.reduce_max(enc_src_lengths) * 2,
                swap_memory=False if 'use_swap_memory' not in self.model_params
                else self.model_params['use_swap_memory'])

        return final_outputs, final_state, final_sequence_lengths
def _testWithMaybeMultiAttention(self,
                                 is_multi,
                                 create_attention_mechanisms,
                                 expected_final_output,
                                 expected_final_state,
                                 attention_mechanism_depths,
                                 alignment_history=False,
                                 expected_final_alignment_history=None,
                                 attention_layer_sizes=None,
                                 name=''):
    # Allow is_multi to be True with a single mechanism to enable test for
    # passing in a single mechanism in a list.
    assert len(create_attention_mechanisms) == 1 or is_multi
    encoder_sequence_length = [3, 2, 3, 1, 1]
    decoder_sequence_length = [2, 0, 1, 2, 3]
    batch_size = 5
    encoder_max_time = 8
    decoder_max_time = 4
    input_depth = 7
    encoder_output_depth = 10
    cell_depth = 9

    if attention_layer_sizes is None:
        attention_depth = encoder_output_depth * len(create_attention_mechanisms)
    else:
        # Compute sum of attention_layer_sizes. Use encoder_output_depth if None.
        attention_depth = sum([attention_layer_size or encoder_output_depth
                               for attention_layer_size in attention_layer_sizes])

    decoder_inputs = array_ops.placeholder_with_default(
        np.random.randn(batch_size, decoder_max_time,
                        input_depth).astype(np.float32),
        shape=(None, None, input_depth))
    encoder_outputs = array_ops.placeholder_with_default(
        np.random.randn(batch_size, encoder_max_time,
                        encoder_output_depth).astype(np.float32),
        shape=(None, None, encoder_output_depth))

    attention_mechanisms = [
        creator(num_units=depth,
                memory=encoder_outputs,
                memory_sequence_length=encoder_sequence_length)
        for creator, depth in zip(create_attention_mechanisms,
                                  attention_mechanism_depths)]

    with self.test_session(use_gpu=True) as sess:
        with vs.variable_scope(
                'root',
                initializer=init_ops.random_normal_initializer(stddev=0.01,
                                                               seed=3)):
            cell = rnn_cell.LSTMCell(cell_depth)
            cell = wrapper.AttentionWrapper(
                cell,
                attention_mechanisms if is_multi else attention_mechanisms[0],
                attention_layer_size=(attention_layer_sizes if is_multi
                                      else attention_layer_sizes[0]),
                alignment_history=alignment_history)
            helper = helper_py.TrainingHelper(decoder_inputs,
                                              decoder_sequence_length)
            my_decoder = basic_decoder.BasicDecoder(
                cell=cell,
                helper=helper,
                initial_state=cell.zero_state(
                    dtype=dtypes.float32, batch_size=batch_size))
            final_outputs, final_state, _ = decoder.dynamic_decode(my_decoder)

        self.assertTrue(
            isinstance(final_outputs, basic_decoder.BasicDecoderOutput))
        self.assertTrue(
            isinstance(final_state, wrapper.AttentionWrapperState))
        self.assertTrue(
            isinstance(final_state.cell_state, rnn_cell.LSTMStateTuple))

        self.assertEqual((batch_size, None, attention_depth),
                         tuple(final_outputs.rnn_output.get_shape().as_list()))
        self.assertEqual((batch_size, None),
                         tuple(final_outputs.sample_id.get_shape().as_list()))
        self.assertEqual((batch_size, attention_depth),
                         tuple(final_state.attention.get_shape().as_list()))
        self.assertEqual((batch_size, cell_depth),
                         tuple(final_state.cell_state.c.get_shape().as_list()))
        self.assertEqual((batch_size, cell_depth),
                         tuple(final_state.cell_state.h.get_shape().as_list()))

        if alignment_history:
            if is_multi:
                state_alignment_history = []
                for history_array in final_state.alignment_history:
                    history = history_array.stack()
                    self.assertEqual(
                        (None, batch_size, None),
                        tuple(history.get_shape().as_list()))
                    state_alignment_history.append(history)
                state_alignment_history = tuple(state_alignment_history)
            else:
                state_alignment_history = final_state.alignment_history.stack()
                self.assertEqual(
                    (None, batch_size, None),
                    tuple(state_alignment_history.get_shape().as_list()))
            nest.assert_same_structure(
                cell.state_size, cell.zero_state(batch_size, dtypes.float32))
            # Remove the history from final_state for purposes of the
            # remainder of the tests.
            final_state = final_state._replace(alignment_history=())  # pylint: disable=protected-access
        else:
            state_alignment_history = ()

        sess.run(variables.global_variables_initializer())
        sess_results = sess.run({
            'final_outputs': final_outputs,
            'final_state': final_state,
            'state_alignment_history': state_alignment_history,
        })

        final_output_info = nest.map_structure(get_result_summary,
                                               sess_results['final_outputs'])
        final_state_info = nest.map_structure(get_result_summary,
                                              sess_results['final_state'])
        print(name)
        print('Copy/paste:\nexpected_final_output = %s' % str(final_output_info))
        print('expected_final_state = %s' % str(final_state_info))
        nest.map_structure(self.assertAllCloseOrEqual, expected_final_output,
                           final_output_info)
        nest.map_structure(self.assertAllCloseOrEqual, expected_final_state,
                           final_state_info)
        if alignment_history:  # by default, the wrapper emits attention as output
            final_alignment_history_info = nest.map_structure(
                get_result_summary, sess_results['state_alignment_history'])
            print('expected_final_alignment_history = %s' %
                  str(final_alignment_history_info))
            nest.map_structure(
                self.assertAllCloseOrEqual,
                # outputs are batch major but the stacked TensorArray is time major
                expected_final_alignment_history,
                final_alignment_history_info)
def _build_decoder(self, encoder_outputs, encoder_state):
    with tf.name_scope("seq_decoder"):
        batch_size = self.batch_size
        # sequence_length = tf.fill([self.batch_size], self.num_steps)
        if self.mode == tf.contrib.learn.ModeKeys.TRAIN:
            sequence_length = self.iterator.target_length
        else:
            sequence_length = self.iterator.source_length

        # For beam-search inference, tile the encoder memory, state and lengths
        # so that every beam hypothesis shares the same encoder context.
        if (self.mode != tf.contrib.learn.ModeKeys.TRAIN) and self.beam_width > 1:
            batch_size = batch_size * self.beam_width
            encoder_outputs = beam_search_decoder.tile_batch(
                encoder_outputs, multiplier=self.beam_width)
            encoder_state = nest.map_structure(
                lambda s: beam_search_decoder.tile_batch(s, self.beam_width),
                encoder_state)
            sequence_length = beam_search_decoder.tile_batch(
                sequence_length, multiplier=self.beam_width)

        single_cell = single_rnn_cell(self.hparams.unit_type, self.num_units,
                                      self.dropout)
        decoder_cell = MultiRNNCell(
            [single_cell for _ in range(self.num_layers_decoder)])
        decoder_cell = InputProjectionWrapper(decoder_cell,
                                              num_proj=self.num_units)

        attention_mechanism = create_attention_mechanism(
            self.hparams.attention_mechanism, self.num_units,
            memory=encoder_outputs, source_sequence_length=sequence_length)

        decoder_cell = wrapper.AttentionWrapper(
            decoder_cell,
            attention_mechanism,
            attention_layer_size=self.num_units,
            output_attention=True,
            alignment_history=False)

        # Set the AttentionWrapperState's cell_state to the encoder's final state.
        initial_state = decoder_cell.zero_state(
            batch_size=batch_size, dtype=tf.float32).clone(
                cell_state=encoder_state)

        embeddings_decoder = tf.get_variable(
            "embedding_decoder",
            [self.num_decoder_symbols, self.num_units],
            initializer=self.initializer, dtype=tf.float32)

        output_layer = Dense(units=self.num_decoder_symbols, use_bias=True,
                             name="output_layer")

        if self.mode == tf.contrib.learn.ModeKeys.TRAIN:
            decoder_inputs = tf.nn.embedding_lookup(
                embeddings_decoder, self.iterator.target_in)
            decoder_helper = helper.TrainingHelper(
                decoder_inputs, sequence_length=sequence_length)
            dec = basic_decoder.BasicDecoder(decoder_cell, decoder_helper,
                                             initial_state,
                                             output_layer=output_layer)
            final_outputs, final_state, _ = decoder.dynamic_decode(dec)
            output_ids = final_outputs.rnn_output
            outputs = final_outputs.sample_id
        else:
            def embedding_fn(inputs):
                return tf.nn.embedding_lookup(embeddings_decoder, inputs)

            decoding_length_factor = 2.0
            max_encoder_length = tf.reduce_max(self.iterator.source_length)
            maximum_iterations = tf.to_int32(
                tf.round(tf.to_float(max_encoder_length) *
                         decoding_length_factor))

            tgt_sos_id = tf.cast(
                self.tgt_vocab_table.lookup(tf.constant(self.hparams.sos)),
                tf.int32)
            tgt_eos_id = tf.cast(
                self.tgt_vocab_table.lookup(tf.constant(self.hparams.eos)),
                tf.int32)
            start_tokens = tf.fill([self.batch_size], tgt_sos_id)
            end_token = tgt_eos_id

            if self.beam_width == 1:
                decoder_helper = helper.GreedyEmbeddingHelper(
                    embedding=embedding_fn, start_tokens=start_tokens,
                    end_token=end_token)
                dec = basic_decoder.BasicDecoder(decoder_cell, decoder_helper,
                                                 initial_state,
                                                 output_layer=output_layer)
            else:
                dec = beam_search_decoder.BeamSearchDecoder(
                    cell=decoder_cell,
                    embedding=embedding_fn,
                    start_tokens=start_tokens,
                    end_token=end_token,
                    initial_state=initial_state,
                    output_layer=output_layer,
                    beam_width=self.beam_width)
            final_outputs, final_state, _ = decoder.dynamic_decode(
                dec,
                # swap_memory=True,
                maximum_iterations=maximum_iterations)

        if self.mode == tf.contrib.learn.ModeKeys.TRAIN or self.beam_width == 1:
            output_ids = final_outputs.sample_id
            outputs = final_outputs.rnn_output
        else:
            output_ids = final_outputs.predicted_ids
            outputs = final_outputs.beam_search_decoder_output.scores

        return output_ids, outputs
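# The decoder above relies on tile_batch before beam search. A minimal sketch of
# what that call does (an illustration only, assuming TF 1.x graph mode with
# tf.contrib.seq2seq and made-up shapes): each batch entry of the encoder memory
# is repeated beam_width times, so all beams of one example attend to the same
# memory.
import numpy as np
import tensorflow as tf

beam_width = 3
enc = tf.constant(np.random.randn(2, 4, 8), dtype=tf.float32)  # [batch, time, depth]
tiled = tf.contrib.seq2seq.tile_batch(enc, multiplier=beam_width)

with tf.Session() as sess:
    print(sess.run(tf.shape(tiled)))  # [6, 4, 8] == [batch * beam_width, time, depth]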
def _testDynamicDecodeRNN(self, time_major, has_attention):
    encoder_sequence_length = np.array([3, 2, 3, 1, 1])
    decoder_sequence_length = np.array([2, 0, 1, 2, 3])
    batch_size = 5
    decoder_max_time = 4
    input_depth = 7
    cell_depth = 9
    attention_depth = 6
    vocab_size = 20
    end_token = vocab_size - 1
    start_token = 0
    embedding_dim = 50
    max_out = max(decoder_sequence_length)
    output_layer = layers_core.Dense(vocab_size, use_bias=True, activation=None)
    beam_width = 3

    with self.test_session() as sess:
        batch_size_tensor = constant_op.constant(batch_size)
        embedding = np.random.randn(vocab_size, embedding_dim).astype(np.float32)
        cell = rnn_cell.LSTMCell(cell_depth)
        initial_state = cell.zero_state(batch_size, dtypes.float32)
        if has_attention:
            inputs = array_ops.placeholder_with_default(
                np.random.randn(batch_size, decoder_max_time,
                                input_depth).astype(np.float32),
                shape=(None, None, input_depth))
            tiled_inputs = beam_search_decoder.tile_batch(
                inputs, multiplier=beam_width)
            tiled_sequence_length = beam_search_decoder.tile_batch(
                encoder_sequence_length, multiplier=beam_width)
            attention_mechanism = attention_wrapper.BahdanauAttention(
                num_units=attention_depth,
                memory=tiled_inputs,
                memory_sequence_length=tiled_sequence_length)
            initial_state = beam_search_decoder.tile_batch(
                initial_state, multiplier=beam_width)
            cell = attention_wrapper.AttentionWrapper(
                cell=cell,
                attention_mechanism=attention_mechanism,
                attention_layer_size=attention_depth,
                alignment_history=False)
        cell_state = cell.zero_state(
            dtype=dtypes.float32, batch_size=batch_size_tensor * beam_width)
        if has_attention:
            cell_state = cell_state.clone(cell_state=initial_state)
        bsd = beam_search_decoder.BeamSearchDecoder(
            cell=cell,
            embedding=embedding,
            start_tokens=array_ops.fill([batch_size_tensor], start_token),
            end_token=end_token,
            initial_state=cell_state,
            beam_width=beam_width,
            output_layer=output_layer,
            length_penalty_weight=0.0)

        final_outputs, final_state, final_sequence_lengths = (
            decoder.dynamic_decode(
                bsd, output_time_major=time_major, maximum_iterations=max_out))

        def _t(shape):
            if time_major:
                return (shape[1], shape[0]) + shape[2:]
            return shape

        self.assertTrue(
            isinstance(final_outputs,
                       beam_search_decoder.FinalBeamSearchDecoderOutput))
        self.assertTrue(
            isinstance(final_state, beam_search_decoder.BeamSearchDecoderState))

        beam_search_decoder_output = final_outputs.beam_search_decoder_output
        self.assertEqual(
            _t((batch_size, None, beam_width)),
            tuple(beam_search_decoder_output.scores.get_shape().as_list()))
        self.assertEqual(
            _t((batch_size, None, beam_width)),
            tuple(final_outputs.predicted_ids.get_shape().as_list()))

        sess.run(variables.global_variables_initializer())
        sess_results = sess.run({
            'final_outputs': final_outputs,
            'final_state': final_state,
            'final_sequence_lengths': final_sequence_lengths
        })

        max_sequence_length = np.max(sess_results['final_sequence_lengths'])

        # A smoke test
        self.assertEqual(
            _t((batch_size, max_sequence_length, beam_width)),
            sess_results['final_outputs'].beam_search_decoder_output
            .scores.shape)
        self.assertEqual(
            _t((batch_size, max_sequence_length, beam_width)),
            sess_results['final_outputs'].beam_search_decoder_output
            .predicted_ids.shape)
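# Follow-up sketch (an assumption: it reuses the batch-major `final_outputs` from
# the test above, i.e. time_major=False). predicted_ids has shape
# [batch, time, beam_width], with beams sorted by score, so beam index 0 is
# conventionally taken as the best hypothesis for each batch entry.
best_ids = final_outputs.predicted_ids[:, :, 0]                         # [batch, time]
best_scores = final_outputs.beam_search_decoder_output.scores[:, :, 0]  # [batch, time]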
import tensorflow as tf
from tensorflow.python.ops import rnn_cell
from tensorflow.contrib.seq2seq.python.ops import attention_wrapper
from tensorflow.contrib.seq2seq.python.ops import basic_decoder
from tensorflow.contrib.seq2seq.python.ops import helper as helper_py

# tf.enable_eager_execution()

batch_size = 5
src_len = [4, 5, 3, 5, 6]
max_times = 6
num_units = 16

# Fake encoder output: [batch_size, max_times, num_units].
enc_output = tf.random.normal((batch_size, max_times, num_units), dtype=tf.float32)

# Attention-wrapped RNN cell.
rnncell = rnn_cell.LSTMCell(num_units=16)
attention_mechanism = attention_wrapper.BahdanauAttention(
    num_units=num_units, memory=enc_output, memory_sequence_length=src_len)
attnRNNCell = attention_wrapper.AttentionWrapper(
    cell=rnncell,
    attention_mechanism=attention_mechanism,
    alignment_history=True)

# Training setup: fake target inputs and a TrainingHelper.
tgt_len = [5, 6, 2, 7, 4]
tgt_max_times = 7
tgt_inputs = tf.random.normal((batch_size, tgt_max_times, num_units),
                              dtype=tf.float32)
training_helper = helper_py.TrainingHelper(tgt_inputs, tgt_len)  # train helper
train_decoder = basic_decoder.BasicDecoder(
    cell=attnRNNCell,
    helper=training_helper,
    initial_state=attnRNNCell.zero_state(batch_size, tf.float32))
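# A possible continuation (a sketch, assuming the graph-mode snippet above with
# `attnRNNCell` and `train_decoder` already built): run the decoder and inspect
# its outputs and the alignment history recorded because alignment_history=True.
from tensorflow.contrib.seq2seq.python.ops import decoder

final_outputs, final_state, final_lengths = decoder.dynamic_decode(train_decoder)
# alignment_history is a TensorArray; stacking yields [decoder_time, batch_size, max_times].
alignments = final_state.alignment_history.stack()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    rnn_out, align = sess.run([final_outputs.rnn_output, alignments])
    print(rnn_out.shape)  # (batch_size, max(tgt_len), num_units) - attention output
    print(align.shape)    # (max(tgt_len), batch_size, max_times)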