def BiRNNAtt(x, attention, layer):
    """Run an attention-wrapped bidirectional LSTM over ``x``.

    Args:
        x: Input tensor handed to ``tf.nn.bidirectional_dynamic_rnn``
            (presumably batch-major ``(batch, time, features)`` -- confirm
            with the caller).
        attention: Attention mechanism shared by the forward and backward
            ``AttentionWrapper`` cells.
        layer: Value formatted into the variable-scope name
            ``'encoder_sentence_<layer>'``.

    Returns:
        The forward and backward output sequences concatenated on axis 2.
    """
    scope_name = 'encoder_sentence_{}'.format(layer)
    with tf.variable_scope(scope_name, reuse=False):

        def attentive_lstm():
            # ``dims`` is resolved outside this function.
            base_cell = rnn.BasicLSTMCell(dims, forget_bias=1.0)
            return seq2seq.AttentionWrapper(base_cell, attention)

        forward_cell = attentive_lstm()
        backward_cell = attentive_lstm()

        # Final states are not needed; only the per-step outputs are returned.
        (forward_out, backward_out), _ = tf.nn.bidirectional_dynamic_rnn(
            forward_cell, backward_cell, x, dtype=tf.float32)

        return tf.concat([forward_out, backward_out], axis=2)
def getBeamSearchDecoderCell(self, encoder_outputs, encoder_final_states):
    """Build the attention decoder cell and initial state for beam search.

    The encoder outputs, every encoder final state and ``self.enc_len`` are
    replicated ``beam_size`` times with ``seq2seq.tile_batch`` before being
    used as attention memory, initial cell state and memory lengths.

    Args:
        encoder_outputs: Encoder outputs used as the attention memory.
        encoder_final_states: Iterable of encoder final states, one entry per
            stacked decoder layer.

    Returns:
        Tuple ``(att_cell, initial_state)``: the ``AttentionWrapper`` and the
        state returned by its ``zero_state`` for ``batch * beam_size`` rows.
    """
    # ``layer_num``, ``beam_size`` and ``num_units`` are resolved outside
    # this method.
    stacked_cell = tf.nn.rnn_cell.MultiRNNCell(
        [self.get_basicLSTMCell() for _ in range(layer_num)])

    beam_memory = seq2seq.tile_batch(encoder_outputs, multiplier=beam_size)
    beam_cell_state = tuple(
        seq2seq.tile_batch(layer_state, multiplier=beam_size)
        for layer_state in encoder_final_states)
    beam_memory_len = seq2seq.tile_batch(self.enc_len, multiplier=beam_size)

    bahdanau = seq2seq.BahdanauAttention(
        num_units=num_units,
        memory=beam_memory,
        memory_sequence_length=beam_memory_len)

    att_cell = seq2seq.AttentionWrapper(
        stacked_cell,
        attention_mechanism=bahdanau,
        attention_layer_size=num_units,
        alignment_history=False,
        cell_input_fn=None,
        initial_cell_state=beam_cell_state)

    # The wrapper was given ``initial_cell_state``, so its zero state carries
    # the tiled encoder states as the cell state.
    beam_batch_size = tf.shape(self.enc_in)[0] * beam_size
    initial_state = att_cell.zero_state(
        batch_size=beam_batch_size, dtype=tf.float32)
    return att_cell, initial_state
def create_decoder_cell(agenda, base_sent_embeds, mev_st, mev_ts,
                        base_length, iw_length, dw_length,
                        attn_dim, hidden_dim, num_layer,
                        enable_alignment_history=False, enable_dropout=False,
                        dropout_keep=0.1, no_insert_delete_attn=False):
    """Assemble the multi-layer decoder cell with multi-source attention.

    The bottom LSTM layer is wrapped with an ``AttentionWrapper`` attending
    over the base sentence and the ``mev_st`` pair (plus the ``mev_ts`` pair
    unless ``no_insert_delete_attn`` is set). The remaining
    ``num_layer - 1`` LSTM layers are stacked on top, and everything is
    handed to ``AttentionAugmentRNNCell``.

    Args:
        agenda: Passed to ``decoder_cell.set_agenda``.
        base_sent_embeds: Memory for the ``'src_attn'`` mechanism.
        mev_st: Pair ``(memory, values)`` for the ``'mev_st_attn'`` mechanism.
        mev_ts: Pair ``(memory, values)`` for the ``'mev_ts_attn'`` mechanism.
        base_length: Memory sequence lengths for ``base_sent_embeds``.
        iw_length: Memory sequence lengths for the ``mev_st`` memory.
        dw_length: Memory sequence lengths for the ``mev_ts`` memory.
        attn_dim: ``num_units`` of every Bahdanau mechanism.
        hidden_dim: Size of every LSTM layer.
        num_layer: Total number of LSTM layers, including the bottom one.
        enable_alignment_history: Forwarded to ``AttentionWrapper``.
        enable_dropout: Enables output dropout on the upper layers.
        dropout_keep: Output keep probability used when dropout is enabled.
        no_insert_delete_attn: When true, the ``mev_ts`` mechanism is omitted.

    Returns:
        The configured ``AttentionAugmentRNNCell``.
    """
    src_attention = seq2seq.BahdanauAttention(
        attn_dim, base_sent_embeds, base_length, name='src_attn')

    st_memory, st_values = mev_st
    st_attention = seq2seq.BahdanauAttention(
        attn_dim, st_memory, iw_length, name='mev_st_attn')
    # Overwrite the mechanism's private values with the second pair element.
    st_attention._values = st_values

    mechanisms = [src_attention, st_attention]
    if not no_insert_delete_attn:
        ts_memory, ts_values = mev_ts
        ts_attention = seq2seq.BahdanauAttention(
            attn_dim, ts_memory, dw_length, name='mev_ts_attn')
        ts_attention._values = ts_values
        mechanisms.append(ts_attention)

    first_layer = seq2seq.AttentionWrapper(
        tf_rnn.LSTMCell(hidden_dim, name='bottom_cell'),
        tuple(mechanisms),
        output_attention=False,
        alignment_history=enable_alignment_history,
        name='att_bottom_cell')

    layers = [first_layer]
    # Upper layers are named 'layer_1' .. 'layer_<num_layer - 1>'.
    for depth in range(1, num_layer):
        upper = tf_rnn.LSTMCell(hidden_dim, name='layer_%s' % depth)
        if enable_dropout and dropout_keep < 1.:
            upper = tf_rnn.DropoutWrapper(upper, output_keep_prob=dropout_keep)
        layers.append(upper)

    decoder_cell = AttentionAugmentRNNCell(layers)
    decoder_cell.set_agenda(agenda)
    return decoder_cell
def create_rnn_cell(unit_type, num_units, num_layers, num_residual_layers,
                    forget_bias, dropout, mode, attention_mechanism=None,
                    attention_num_heads=1, attention_layer_size=None,
                    output_attention=False, single_cell_fn=None,
                    trainable=True):
    """Returns an instance of an RNN cell.

    Args:
      unit_type: A string that specifies the type of the recurrent unit. Must
        be one of {"lstm", "gru", "lstm_norm", "nas"}.
      num_units: An integer for the number of units per layer.
      num_layers: An integer for the number of recurrent layers.
      num_residual_layers: An integer for the number of residual layers.
      forget_bias: A float for the forget bias in LSTM cells.
      dropout: A float for the recurrent dropout rate.
      mode: TRAIN | EVAL | PREDICT
      attention_mechanism: An instance of
        tf.contrib.seq2seq.AttentionMechanism.
      attention_num_heads: An integer for the number of attention heads.
      attention_layer_size: Optional integer for the size of the attention
        layer. When None, no attention projection layer is requested.
      output_attention: A boolean indicating whether RNN cell outputs
        attention.
      single_cell_fn: A function for building a single RNN cell.
      trainable: A boolean indicating whether the cell weights are trainable.

    Returns:
      An RNNCell instance.
    """
    cell_list = _cell_list(
        unit_type=unit_type,
        num_units=num_units,
        num_layers=num_layers,
        forget_bias=forget_bias,
        dropout=dropout,
        mode=mode,
        num_residual_layers=num_residual_layers,
        single_cell_fn=single_cell_fn,
        trainable=trainable)

    if len(cell_list) == 1:
        # Single layer.
        cell = cell_list[0]
    else:
        # Multiple layers.
        cell = contrib_rnn.MultiRNNCell(cell_list)

    # Wrap with attention, if necessary.
    if attention_mechanism is not None:
        # Replicating None would yield [None, ...], which AttentionWrapper
        # would treat as a list of layer sizes; keep None as "no layer".
        if attention_layer_size is None:
            attention_layer_sizes = None
        else:
            attention_layer_sizes = (
                [attention_layer_size] * attention_num_heads)
        cell = contrib_seq2seq.AttentionWrapper(
            cell,
            [attention_mechanism] * attention_num_heads,
            attention_layer_size=attention_layer_sizes,
            alignment_history=False,
            output_attention=output_attention,
            name="attention")

    return cell
def decode(self, dec_cell, enc_outputs, ctx_outputs):
    """Decode with Bahdanau attention over ``enc_outputs``.

    Args:
        dec_cell: Base decoder RNN cell; it is wrapped with ``CondWrapper``
            (given ``ctx_outputs``) and then with ``AttentionWrapper``.
        enc_outputs: Encoder outputs used as the attention memory, masked
            with ``self.input_lengths``.
        ctx_outputs: Context passed to ``CondWrapper``.

    Returns:
        Tuple ``(rnn_output, sample_id, dec_state)`` taken from the decoder
        output and final state.
    """
    with tf.variable_scope("decode"):
        attn_mech = seq2seq.BahdanauAttention(
            self._memory_size, enc_outputs, self.input_lengths)

        dec_cell = CondWrapper(dec_cell, ctx_outputs)
        dec_cell = seq2seq.AttentionWrapper(
            dec_cell, attn_mech, self._memory_size)
        dec_initial_state = dec_cell.zero_state(
            batch_size=self._batch_size, dtype=tf.float32)

        # Sampling helper at inference time, teacher forcing otherwise.
        helper_build_fn = (
            self._infer_helper if self._infer else self._train_helper)
        output_layer = layers_core.Dense(
            self._vocab_size, use_bias=True, activation=None)

        decoder = seq2seq.BasicDecoder(
            cell=dec_cell,
            helper=helper_build_fn(),
            initial_state=dec_initial_state,
            output_layer=output_layer)

        # dynamic_decode also returns the final sequence lengths; slicing
        # keeps only (outputs, state) so the unpacking cannot fail on the
        # extra element.
        dec_output, dec_state = seq2seq.dynamic_decode(
            decoder,
            impute_finished=True,
            maximum_iterations=self._max_seq_length)[:2]

        rnn_output = dec_output.rnn_output
        sample_id = dec_output.sample_id
        return rnn_output, sample_id, dec_state
def build_attention_cell(self, encoder_outputs, encoder_states):
    """Build the Luong-attention decoder cell and its initial state.

    Args:
        encoder_outputs: Encoder outputs used as the attention memory. They
            are transposed with ``[1, 0, 2]`` when ``self.time_major`` is set.
        encoder_states: Encoder final state, cloned into the decoder's
            initial ``AttentionWrapperState`` as its cell state.

    Returns:
        Tuple ``(cell, decoder_initial_state)``.
    """
    memory = encoder_outputs
    if self.time_major:
        memory = tf.transpose(encoder_outputs, [1, 0, 2])

    # The memory is the encoder output, so it must be masked with the source
    # lengths; masking with target lengths hides or exposes the wrong steps.
    attention_mechanism = seq2seq.LuongAttention(
        num_units=self.hps.att_num_units,
        memory=memory,
        memory_sequence_length=self.iterator.source_length)

    cell = rnn.MultiRNNCell([
        self.build_rnn_cell(FLAGS.cell_type)
        for _ in range(self.hps.stack_layers)
    ])
    cell = seq2seq.AttentionWrapper(
        cell,
        attention_mechanism,
        attention_layer_size=self.hps.att_num_units,
        name='attention')

    batch_size = tf.size(self.iterator.source_length)
    # ``dtype`` is resolved outside this method.
    decoder_initial_state = cell.zero_state(
        batch_size=batch_size, dtype=dtype).clone(cell_state=encoder_states)
    return cell, decoder_initial_state
def inference_decode_layer(self, start_token, dec_cell, end_token,
                           output_layer):
    """Build the beam-search inference decoder.

    Encoder output, encoder state and source lengths are tiled
    ``self.Beam_width`` times before the attention cell is built.

    Args:
        start_token: Integer id fed as the first decoder input for every
            batch entry.
        dec_cell: Base decoder RNN cell to wrap with attention.
        end_token: Integer id that terminates a hypothesis.
        output_layer: Projection layer applied to the decoder outputs.

    Returns:
        The first element returned by ``seq2seq.dynamic_decode`` for the
        ``BeamSearchDecoder``.
    """
    beam = self.Beam_width
    attention_units = self.hidden_dim * 2

    start_tokens = tf.tile(
        tf.constant([start_token], dtype=tf.int32),
        [self.batch_size],
        name='start_token')

    beam_enc_output = seq2seq.tile_batch(self.enc_output, multiplier=beam)
    beam_enc_state = seq2seq.tile_batch(self.enc_state, multiplier=beam)
    beam_source_len = seq2seq.tile_batch(self.source_len, multiplier=beam)

    attention = seq2seq.BahdanauAttention(
        attention_units,
        beam_enc_output,
        beam_source_len,
        normalize=True)
    attentive_cell = seq2seq.AttentionWrapper(
        dec_cell, attention, attention_units)

    zero_state = attentive_cell.zero_state(
        self.batch_size * beam, tf.float32)
    initial_state = zero_state.clone(cell_state=beam_enc_state)

    decoder = seq2seq.BeamSearchDecoder(
        attentive_cell,
        self.embeddings,
        start_tokens,
        end_token,
        initial_state,
        beam_width=beam,
        output_layer=output_layer)

    decode_result = seq2seq.dynamic_decode(
        decoder,
        output_time_major=False,
        impute_finished=False,
        maximum_iterations=self.max_target_len)
    infer_logits, _, _ = decode_result
    return infer_logits
def train_decode_layer(self, dec_embeddig_input, dec_cell, output_layer):
    """Build the teacher-forced training decoder.

    Args:
        dec_embeddig_input: Embedded decoder inputs fed by the
            ``TrainingHelper``.
        dec_cell: Base decoder RNN cell to wrap with attention.
        output_layer: Projection layer applied to the decoder outputs.

    Returns:
        The first element returned by ``seq2seq.dynamic_decode`` for the
        ``BasicDecoder``.
    """
    # The attention memory is the encoder output, so it is masked with the
    # source lengths (the same lengths inference_decode_layer uses); target
    # lengths only describe the decoder-side inputs.
    atten_mech = seq2seq.BahdanauAttention(
        num_units=self.hidden_dim * 2,
        memory=self.enc_output,
        memory_sequence_length=self.source_len,
        normalize=True,
        name='BahadanauAttention')
    dec_cell = seq2seq.AttentionWrapper(
        dec_cell, atten_mech, self.hidden_dim * 2, name='dec_attention_cell')

    # Start decoding from the encoder's final state.
    initial_state = dec_cell.zero_state(
        batch_size=self.batch_size,
        dtype=tf.float32).clone(cell_state=self.enc_state)

    train_helper = seq2seq.TrainingHelper(dec_embeddig_input, self.target_len)
    training_decoder = seq2seq.BasicDecoder(
        dec_cell,
        train_helper,
        initial_state=initial_state,
        output_layer=output_layer)

    train_logits, _, _ = seq2seq.dynamic_decode(
        training_decoder,
        output_time_major=False,
        impute_finished=False,
        maximum_iterations=self.max_target_len)
    return train_logits
def add_attention(
        cells,
        attention_types,
        num_units,
        memory,
        memory_len,
        mode,
        batch_size,
        dtype,
        beam_search=False,
        beam_width=None,
        initial_state=None,
        write_attention_alignment=False,
        fusion_type='linear_fusion',
):
    r"""
    Wraps the decoder_cells with an AttentionWrapper
    Args:
        cells: instances of `RNNCell`
        attention_types: forwarded to `create_attention_mechanisms`
        num_units: forwarded to `create_attention_mechanisms`
        memory: forwarded to `create_attention_mechanisms`
        memory_len: forwarded to `create_attention_mechanisms`
        mode: forwarded to `create_attention_mechanisms`
        batch_size: `Tensor` containing the batch size. Necessary to the
            initialisation of the initial state
        dtype: dtype of the attention mechanisms and of the zero state
        beam_search: `bool` flag for beam search decoders
        beam_width: beam width; required when `beam_search` is True
        initial_state: optional decoder cell state to start from. It is tiled
            `beam_width` times when `beam_search` is True
        write_attention_alignment: forwarded as `alignment_history`
        fusion_type: forwarded to `create_attention_mechanisms`

    Returns:
        attention_cells: the Attention wrapped decoder cells
        initial_state: the wrapper's zero state with `initial_state` cloned in
            as its cell state, or None when no `initial_state` was given
    """
    use_beam = beam_search is True

    (attention_mechanisms, attention_layers, attention_layer_sizes,
     output_attention) = create_attention_mechanisms(
        beam_search=beam_search,
        beam_width=beam_width,
        memory=memory,
        memory_len=memory_len,
        num_units=num_units,
        attention_types=attention_types,
        fusion_type=fusion_type,
        mode=mode,
        dtype=dtype)

    # tile_batch cannot convert None to a tensor, so only tile a state that
    # was actually supplied.
    if use_beam and initial_state is not None:
        initial_state = seq2seq.tile_batch(
            initial_state, multiplier=beam_width)

    attention_cells = seq2seq.AttentionWrapper(
        cell=cells,
        attention_mechanism=attention_mechanisms,
        attention_layer_size=attention_layer_sizes,
        alignment_history=write_attention_alignment,
        output_attention=output_attention,
        attention_layer=attention_layers,
    )

    # Under beam search every batch entry is expanded into beam_width rows.
    zero_batch_size = batch_size * beam_width if use_beam else batch_size
    attn_zero = attention_cells.zero_state(
        dtype=dtype, batch_size=zero_batch_size)

    if initial_state is not None:
        initial_state = attn_zero.clone(cell_state=initial_state)

    return attention_cells, initial_state
def decoding_layer(self, dec_embed_input, embeddings, enc_output, enc_state,
                   vocab_size, text_len, summary_len, max_sum_len):
    """Build the attention decoder and its training and inference graphs.

    Args:
        dec_embed_input: Embedded decoder inputs for training.
        embeddings: Embedding matrix used by the inference decoder.
        enc_output: Encoder outputs used as the attention memory.
        enc_state: Encoder final state; it is unpacked into an
            ``LSTMStateTuple`` and used as the decoder's initial cell state.
        vocab_size: Number of units of the output projection layer.
        text_len: Lengths of the encoder sequences, used to mask the memory.
        summary_len: Lengths of the target sequences.
        max_sum_len: Maximum decoding length.

    Returns:
        Tuple ``(training_logits, inference_logits)``.
    """
    # Draw the LSTM weights uniformly from [-0.1, 0.1]. A normal initializer
    # given (-0.1, 0.1) would read them as mean and stddev and centre every
    # weight on -0.1.
    lstm = rnn.LSTMCell(
        self.hidden_dim * 2,
        initializer=tf.random_uniform_initializer(-0.1, 0.1, seed=2))
    dec_cell = rnn.DropoutWrapper(lstm, input_keep_prob=self.keep_prob)

    output_layer = tf.layers.Dense(
        vocab_size,
        kernel_initializer=tf.truncated_normal_initializer(stddev=0.1))

    attn_mech = seq2seq.BahdanauAttention(
        self.hidden_dim * 2,
        enc_output,
        text_len,
        normalize=False,
        name='BahdanauAttention')
    dec_cell = seq2seq.AttentionWrapper(
        dec_cell, attn_mech, attention_layer_size=self.hidden_dim * 2)

    # Start from the wrapper's zero state and substitute the encoder state
    # as the cell state.
    initial_state = dec_cell.zero_state(self.batch_size, tf.float32).clone(
        cell_state=LSTMStateTuple(*enc_state))

    with tf.variable_scope('decode'):
        traing_logits = self.training_decoding_layer(
            dec_embed_input, summary_len, dec_cell, initial_state,
            output_layer, max_sum_len)

    # Reuse the variables created by the training decoder.
    with tf.variable_scope('decode', reuse=True):
        inference_logits = self.inference_decoding_layer(
            embeddings,
            self.vocab_to_int['<GO>'],
            self.vocab_to_int['<EOS>'],
            dec_cell,
            initial_state,
            output_layer,
            max_sum_len)

    return traing_logits, inference_logits
def _build_decoder(self):
    """Create the decoder placeholders and the attention GRU decoder.

    Sets ``sequence_decoder``, ``length_decoder``, ``initial_decode_state``,
    ``decoder_outputs`` and ``decoder_final_state`` on ``self``. The decoder
    attends over ``self.encoder_outputs`` and starts from
    ``self.states_keywords``.
    """
    self.sequence_decoder = tf.placeholder(
        dtype=tf.float32,
        shape=[_BATCH_SIZE, None, CHAR_VEC_DIM],
        name='context')
    self.length_decoder = tf.placeholder(
        dtype=tf.int32,
        shape=[_BATCH_SIZE],
        name='length_keywords')

    bahdanau = seq2seq.BahdanauAttention(
        _NUM_UNITS,
        memory=self.encoder_outputs,
        memory_sequence_length=self.context_length,
        name="BahdanauAttention")
    attentive_gru = seq2seq.AttentionWrapper(
        tf.contrib.rnn.GRUCell(_NUM_UNITS), bahdanau)

    # Zero attention state with the keyword encoder state as the cell state.
    zero_state = attentive_gru.zero_state(_BATCH_SIZE, dtype=tf.float32)
    self.initial_decode_state = zero_state.clone(
        cell_state=self.states_keywords)

    rnn_result = tf.nn.dynamic_rnn(
        attentive_gru,
        self.sequence_decoder,
        sequence_length=self.length_decoder,
        initial_state=self.initial_decode_state,
        dtype=tf.float32,
        time_major=False)
    self.decoder_outputs, self.decoder_final_state = rnn_result
def _build_model(self, batch_size, helper_build_fn, decoder_maxiters=None,
                 alignment_history=False):
    """Build the conv-encoder / monotonic-attention decoder graph.

    Args:
        batch_size: Batch size used for the decoder's zero state.
        helper_build_fn: Zero-argument callable returning the decoder helper.
        decoder_maxiters: Optional cap on the number of decoding steps.
        alignment_history: Forwarded to the ``AttentionWrapper``.

    Sets ``self.outputs``, ``self.output_ids`` and ``self.final_state``.
    """
    # One-hot encode the integer inputs.
    one_hot_inputs = tf.one_hot(
        self.input_data, self._input_size, dtype=self._dtype)
    input_lengths = self.input_lengths

    with tf.name_scope('conv-encoder'):
        conv_weights = tf.Variable(
            tf.truncated_normal(
                [3, self._input_size, self._enc_size], stddev=0.1),
            name="conv-weights")
        conv_bias = tf.Variable(
            tf.truncated_normal([self._enc_size], stddev=0.1),
            name="conv-bias")
        conv_out = tf.nn.conv1d(
            one_hot_inputs, conv_weights, stride=1, padding='SAME')
        enc_out = tf.nn.elu(conv_out + conv_bias)

    with tf.name_scope('attn-decoder'):
        first_input_cell = rnn.GRUCell(self._dec_size)
        second_input_cell = rnn.GRUCell(self._dec_size)

        monotonic_attention = seq2seq.LuongMonotonicAttention(
            self._enc_size,
            enc_out,
            memory_sequence_length=input_lengths,
            sigmoid_noise=0.5,
            score_bias_init=-4.,
            mode='recursive',
            scale=True)

        # Two-layer GRU stack that carries the attention.
        attention_stack = rnn.MultiRNNCell(
            [rnn.GRUCell(self._dec_size), rnn.GRUCell(self._enc_size)],
            state_is_tuple=True)
        attention_stack = seq2seq.AttentionWrapper(
            attention_stack,
            monotonic_attention,
            attention_layer_size=self._enc_size,
            alignment_history=alignment_history)

        output_cell = rnn.GRUCell(self._output_size)
        full_cell = rnn.MultiRNNCell(
            [first_input_cell, second_input_cell, attention_stack,
             output_cell],
            state_is_tuple=True)

        decoder = seq2seq.BasicDecoder(
            full_cell,
            helper_build_fn(),
            full_cell.zero_state(batch_size, self._dtype))

        dec_out, dec_state, _ = seq2seq.dynamic_decode(
            decoder,
            output_time_major=False,
            maximum_iterations=decoder_maxiters,
            impute_finished=True)

        self.outputs = dec_out.rnn_output
        self.output_ids = dec_out.sample_id
        self.final_state = dec_state
def _build_decoder_beam_search(self):
    """Build the beam-search inference decoder and store its outputs.

    Sets ``_decoder_inference``, ``inference_outputs``,
    ``inference_predicted_ids`` (first beam only),
    ``inference_predicted_beam`` and ``beam_search_output`` on ``self``, plus
    ``attention_summary`` when ``write_attention_alignment`` is enabled.
    """
    hparams = self._hparams
    beam_width = hparams.beam_width

    batch_size, _ = tf.unstack(tf.shape(self._labels))

    attention_mechanisms, layer_sizes = self._create_attention_mechanisms(
        beam_search=True)

    # Replicate the encoder-derived state once per beam.
    tiled_state = seq2seq.tile_batch(
        self._decoder_initial_state, multiplier=beam_width)

    if hparams.enable_attention is True:
        cells = seq2seq.AttentionWrapper(
            cell=self._decoder_cells,
            attention_mechanism=attention_mechanisms,
            attention_layer_size=layer_sizes,
            initial_cell_state=tiled_state,
            alignment_history=hparams.write_attention_alignment,
            output_attention=self._output_attention)
        zero_state = cells.zero_state(
            dtype=hparams.dtype, batch_size=batch_size * beam_width)
        initial_state = zero_state.clone(cell_state=tiled_state)
    else:
        cells = self._decoder_cells
        initial_state = tiled_state

    self._decoder_inference = seq2seq.BeamSearchDecoder(
        cell=cells,
        embedding=self._embedding_matrix,
        start_tokens=array_ops.fill([batch_size], self._GO_ID),
        end_token=self._EOS_ID,
        initial_state=initial_state,
        beam_width=beam_width,
        output_layer=self._dense_layer,
        length_penalty_weight=0.5,
    )

    outputs, states, lengths = seq2seq.dynamic_decode(
        self._decoder_inference,
        impute_finished=False,
        maximum_iterations=hparams.max_label_length,
        swap_memory=False)

    if hparams.write_attention_alignment is True:
        self.attention_summary = self._create_attention_alignments_summary(
            states)

    self.inference_outputs = outputs.beam_search_decoder_output
    # Index 0 on the last axis selects the first beam.
    self.inference_predicted_ids = outputs.predicted_ids[:, :, 0]
    self.inference_predicted_beam = outputs.predicted_ids
    self.beam_search_output = outputs.beam_search_decoder_output
def _build_model(self, batch_size, helper_build_fn, decoder_maxiters=None,
                 alignment_history=False):
    """Build the bidirectional-encoder / Bahdanau-attention decoder graph.

    Args:
        batch_size: Batch size used for the encoder and decoder zero states.
        helper_build_fn: Zero-argument callable returning the decoder helper.
        decoder_maxiters: Optional cap on the number of decoding steps.
        alignment_history: Forwarded to the ``AttentionWrapper``.

    Sets ``self.outputs``, ``self.output_ids`` and ``self.final_state``.
    """
    # embed input_data into a one-hot representation
    inputs = tf.one_hot(self.input_data, self._input_size, dtype=self._dtype)
    inputs_len = self.input_lengths

    with tf.name_scope('bidir-encoder'):
        fw_cell = rnn.MultiRNNCell(
            [rnn.BasicRNNCell(self._enc_rnn_size) for _ in range(3)],
            state_is_tuple=True)
        bw_cell = rnn.MultiRNNCell(
            [rnn.BasicRNNCell(self._enc_rnn_size) for _ in range(3)],
            state_is_tuple=True)
        fw_cell_zero = fw_cell.zero_state(batch_size, self._dtype)
        bw_cell_zero = bw_cell.zero_state(batch_size, self._dtype)

        enc_out, _ = tf.nn.bidirectional_dynamic_rnn(
            fw_cell, bw_cell, inputs,
            sequence_length=inputs_len,
            initial_state_fw=fw_cell_zero,
            initial_state_bw=bw_cell_zero)

    with tf.name_scope('attn-decoder'):
        dec_cell_in = rnn.GRUCell(self._dec_rnn_size)

        # Forward and backward encoder outputs joined on the feature axis.
        attn_values = tf.concat(enc_out, 2)
        attn_mech = seq2seq.BahdanauAttention(
            self._enc_rnn_size * 2, attn_values, inputs_len)

        dec_cell_attn = rnn.GRUCell(self._enc_rnn_size * 2)
        dec_cell_attn = seq2seq.AttentionWrapper(
            dec_cell_attn,
            attn_mech,
            self._enc_rnn_size * 2,
            alignment_history=alignment_history)

        dec_cell_out = rnn.GRUCell(self._output_size)
        dec_cell = rnn.MultiRNNCell(
            [dec_cell_in, dec_cell_attn, dec_cell_out],
            state_is_tuple=True)

        dec = seq2seq.BasicDecoder(
            dec_cell,
            helper_build_fn(),
            dec_cell.zero_state(batch_size, self._dtype))

        # dynamic_decode also returns the final sequence lengths; slicing
        # keeps only (outputs, state) so the unpacking cannot fail on the
        # extra element.
        dec_out, dec_state = seq2seq.dynamic_decode(
            dec,
            output_time_major=False,
            maximum_iterations=decoder_maxiters,
            impute_finished=True)[:2]

        self.outputs = dec_out.rnn_output
        self.output_ids = dec_out.sample_id
        self.final_state = dec_state
def build_decoder_cell(self):
    """Build the attention-wrapped multi-layer decoder cell.

    With ``self.use_beamsearch_decode`` the encoder outputs, last state and
    input lengths are tiled ``self.beam_width`` times first. Sets
    ``self.attention_mechanism`` and ``self.decoder_cell_list`` (which ends
    up holding the ``AttentionWrapper``).

    Returns:
        Tuple ``(decoder_cell, decoder_initial_state)``.
    """
    if self.use_beamsearch_decode:
        def tile(value):
            return tf.contrib.seq2seq.tile_batch(
                value, multiplier=self.beam_width)

        encoder_outputs = tile(self.encoder_outputs)
        encoder_last_state = tile(self.encoder_last_state)
        encoder_inputs_length = tile(self.encoder_inputs_length)
    else:
        encoder_outputs = self.encoder_outputs
        encoder_last_state = self.encoder_last_state
        encoder_inputs_length = self.encoder_inputs_length

    self.attention_mechanism = seq2seq.BahdanauAttention(
        num_units=self.decoder_hidden_units,
        memory=encoder_outputs,
        memory_sequence_length=encoder_inputs_length)

    self.decoder_cell_list = [
        self.build_single_cell(self.decoder_hidden_units)
        for _ in range(self.depth)
    ]

    def attn_decoder_input_fn(inputs, attention):
        """Optionally project ``[inputs, attention]`` to the hidden size."""
        if self.attn_input_feeding:
            # Essential when use_residual=True
            projection = Dense(
                self.decoder_hidden_units,
                dtype=self.dtype,
                name='attn_input_feeding')
            return projection(rnn.array_ops.concat([inputs, attention], -1))
        return inputs

    # The attention wrapper encloses the whole stack of decoder layers.
    stacked_cell = rnn.MultiRNNCell(self.decoder_cell_list)
    self.decoder_cell_list = seq2seq.AttentionWrapper(
        cell=stacked_cell,
        attention_mechanism=self.attention_mechanism,
        attention_layer_size=self.decoder_hidden_units,
        cell_input_fn=attn_decoder_input_fn,
        initial_cell_state=encoder_last_state,
        alignment_history=False,
        name='attention_wrapper')

    # Beam search expands every batch entry into beam_width rows.
    batch_size = self.batch_size
    if self.use_beamsearch_decode:
        batch_size = self.batch_size * self.beam_width

    zero_state = self.decoder_cell_list.zero_state(
        batch_size=batch_size, dtype=self.dtype)
    decoder_initial_state = zero_state.clone(cell_state=encoder_last_state)

    return self.decoder_cell_list, decoder_initial_state
def build_decoder_cell(self):
    """Build a stacked decoder whose top layer is wrapped with attention.

    Sets ``self.attention_mechanism`` and ``self.decoder_cell_list``.

    Returns:
        Tuple ``(decoder_cell, decoder_initial_state)``: a ``MultiRNNCell``
        over the layer list, and a tuple holding the encoder's last states
        with the top entry replaced by the attention wrapper's zero state.
    """
    # TODO(sdsuo): Read up and decide whether to use beam search
    self.attention_mechanism = seq2seq.BahdanauAttention(
        num_units=self.decoder_hidden_units,
        memory=self.encoder_outputs,
        memory_sequence_length=self.encoder_inputs_length)

    self.decoder_cell_list = [
        self.build_single_cell(self.decoder_hidden_units)
        for _ in range(self.depth)
    ]

    def attn_decoder_input_fn(inputs, attention):
        """Optionally project ``[inputs, attention]`` to the hidden size."""
        if self.attn_input_feeding:
            # Essential when use_residual=True
            projection = Dense(
                self.decoder_hidden_units,
                dtype=self.dtype,
                name='attn_input_feeding')
            return projection(rnn.array_ops.concat([inputs, attention], -1))
        return inputs

    # Only the top decoder layer carries the attention mechanism.
    top_layer = self.decoder_cell_list[-1]
    self.decoder_cell_list[-1] = seq2seq.AttentionWrapper(
        cell=top_layer,
        attention_mechanism=self.attention_mechanism,
        attention_layer_size=self.decoder_hidden_units,
        cell_input_fn=attn_decoder_input_fn,
        initial_cell_state=self.encoder_last_state[-1],
        alignment_history=False,
        name='attention_wrapper')

    # With beam search the zero state needs beam_width rows per batch entry.
    batch_size = self.batch_size
    if self.use_beamsearch_decode:
        batch_size = self.batch_size * self.beam_width

    # The top layer's state must be an AttentionWrapperState, which
    # AttentionWrapper.zero_state provides.
    layer_states = list(self.encoder_last_state)
    layer_states[-1] = self.decoder_cell_list[-1].zero_state(
        batch_size=batch_size, dtype=self.dtype)
    decoder_initial_state = tuple(layer_states)

    return rnn.MultiRNNCell(self.decoder_cell_list), decoder_initial_state
def _build_decoder(self, encoder_states, target_sequence, keep_prob,
                   sampling_prob, attention_mechanism):
    """Build the attention decoder with scheduled output sampling.

    Args:
        encoder_states: Encoder state cloned into the decoder initial state.
        target_sequence: Target values; the batch size is read from its first
            dimension and the last time step is dropped from the inputs.
        keep_prob: Forwarded to ``self._multi_cell``.
        sampling_prob: Sampling probability of the scheduled helper.
        attention_mechanism: Attention mechanism wrapped around the cell.

    Returns:
        Tuple ``(outputs, decoder_state)`` from ``seq2seq.dynamic_decode``.
    """
    stacked_cell = self._multi_cell(
        num_units=self.num_units,
        num_layers=self.num_layers,
        keep_prob=keep_prob)

    attentive_cell = seq2seq.AttentionWrapper(
        cell=stacked_cell,
        attention_mechanism=attention_mechanism,
        attention_layer_size=self.num_units)

    # Shift the targets right by one step, with an all-zero start frame.
    shifted_targets = target_sequence[:, :-1]
    start_frame = tf.fill(
        [tf.shape(target_sequence)[0], 1, self.target_depth], 0.0)
    decoder_inputs = tf.concat([start_frame, shifted_targets], axis=1)

    # Every sequence has the same fixed length; repeat it per batch entry.
    decoder_lengths = tf.tile(
        [self.target_length], [tf.shape(target_sequence)[0]])

    # The helper feeds the decoder's own output back as the next input with
    # probability ``sampling_prob``.
    scheduled_helper = seq2seq.ScheduledOutputTrainingHelper(
        inputs=decoder_inputs,
        sequence_length=decoder_lengths,
        sampling_probability=sampling_prob)

    projection_layer = Dense(units=self.target_depth, use_bias=True)

    zero_state = attentive_cell.zero_state(
        tf.shape(target_sequence)[0], tf.float32)
    initial_state = zero_state.clone(cell_state=encoder_states)

    basic_decoder = seq2seq.BasicDecoder(
        cell=attentive_cell,
        helper=scheduled_helper,
        initial_state=initial_state,
        output_layer=projection_layer)

    outputs, decoder_state, _sequence_lengths = seq2seq.dynamic_decode(
        decoder=basic_decoder, maximum_iterations=self.target_length)

    return (outputs, decoder_state)
def decoder_cell(self, inputs, lengths):
    """Wrap ``self.cell()`` with scaled Luong attention over ``inputs``.

    Args:
        inputs: Attention memory.
        lengths: Sequence lengths used to mask the memory.

    Returns:
        The ``AttentionWrapper`` around a freshly built cell.
    """
    size = self.layer_size
    luong = seq2seq.LuongAttention(
        num_units=size,
        memory=inputs,
        memory_sequence_length=lengths,
        scale=True)
    wrapped = seq2seq.AttentionWrapper(
        cell=self.cell(),
        attention_mechanism=luong,
        attention_layer_size=size)
    return wrapped
def build_decode_cell(self):
    """Build a stacked decoder whose top layer is wrapped with attention.

    Sets ``self.attention_mechanism`` (Bahdanau by default, Luong when
    ``self.attention_type`` is ``'luong'``, case-insensitive) and
    ``self.decoder_cell_list``.

    Returns:
        Tuple ``(decoder_cell, decoder_initial_state)``: a ``MultiRNNCell``
        over the layer list, and a tuple holding the encoder's last states
        with the top entry replaced by the attention wrapper's zero state.
    """
    memory = self.encoder_outputs
    last_state = self.encoder_last_state
    memory_lengths = self.encoder_inputs_length

    # Default mechanism, 'Bahdanau' style: https://arxiv.org/abs/1409.0473
    self.attention_mechanism = seq2seq.BahdanauAttention(
        num_units=self.hidden_units,
        memory=memory,
        memory_sequence_length=memory_lengths)
    if self.attention_type.lower() == 'luong':
        self.attention_mechanism = seq2seq.LuongAttention(
            num_units=self.hidden_units,
            memory=memory,
            memory_sequence_length=memory_lengths)

    self.decoder_cell_list = [
        self.build_single_cell(layer=2) for _ in range(self.depth)
    ]

    def attn_decoder_input_fn(inputs, attention):
        """Optionally project ``[inputs, attention]`` before the cell."""
        if self.attn_input_feeding:
            # Essential when use_residual=True
            projection = Dense(
                self.hidden_units * 2,
                dtype=self.dtype,
                name='attn_input_feeding')
            return projection(tf.concat([inputs, attention], -1))
        return inputs

    # Only the top decoder layer carries the attention mechanism.
    top_layer = self.decoder_cell_list[-1]
    self.decoder_cell_list[-1] = seq2seq.AttentionWrapper(
        cell=top_layer,
        attention_mechanism=self.attention_mechanism,
        attention_layer_size=self.hidden_units,
        cell_input_fn=attn_decoder_input_fn,
        initial_cell_state=last_state[-1],
        alignment_history=False,
        name='Attention_wrapper')

    # The top layer's state must be an AttentionWrapperState, which
    # AttentionWrapper.zero_state provides.
    layer_states = list(last_state)
    layer_states[-1] = self.decoder_cell_list[-1].zero_state(
        batch_size=self.batch_size, dtype=self.dtype)
    decoder_initial_state = tuple(layer_states)

    return rnn.MultiRNNCell(self.decoder_cell_list), decoder_initial_state
def _build_decoder_train(self):
    """Build the training decoder with scheduled embedding sampling.

    Sets ``_labels_embedded``, ``_helper_train``, ``_decoder_train``,
    ``_basic_decoder_train_outputs``, ``_final_states``, ``_final_seq_lens``
    and ``_logits`` on ``self``.
    """
    hparams = self._hparams

    self._labels_embedded = tf.nn.embedding_lookup(
        self._embedding_matrix, self._labels_padded_GO)

    self._helper_train = seq2seq.ScheduledEmbeddingTrainingHelper(
        inputs=self._labels_embedded,
        sequence_length=self._labels_len,
        embedding=self._embedding_matrix,
        sampling_probability=self._sampling_probability_outputs,
    )

    if hparams.enable_attention is True:
        attention_mechanisms, layer_sizes = (
            self._create_attention_mechanisms())

        cells = seq2seq.AttentionWrapper(
            cell=self._decoder_cells,
            attention_mechanism=attention_mechanisms,
            attention_layer_size=layer_sizes,
            initial_cell_state=self._decoder_initial_state,
            alignment_history=False,
            output_attention=self._output_attention,
        )

        batch_size, _ = tf.unstack(tf.shape(self._labels))

        # Zero attention state with the decoder initial state cloned in.
        zero_state = cells.zero_state(
            dtype=hparams.dtype, batch_size=batch_size)
        initial_state = zero_state.clone(
            cell_state=self._decoder_initial_state)
    else:
        cells = self._decoder_cells
        initial_state = self._decoder_initial_state

    self._decoder_train = seq2seq.BasicDecoder(
        cell=cells,
        helper=self._helper_train,
        initial_state=initial_state,
        output_layer=self._dense_layer,
    )

    decode_result = seq2seq.dynamic_decode(
        self._decoder_train,
        output_time_major=False,
        impute_finished=True,
        swap_memory=False,
    )
    (self._basic_decoder_train_outputs,
     self._final_states,
     self._final_seq_lens) = decode_result

    self._logits = self._basic_decoder_train_outputs.rnn_output
def _init_encoder(self):
    """Build the unidirectional encoder whose top layer attends to a memory.

    Only the ``'unidirectional'`` encoder type is handled here. Sets
    ``_encoder_cells``, ``_encoder_outputs`` and ``_encoder_final_state`` on
    ``self``, plus ``attention_summary`` when ``write_attention_alignment``
    is enabled.
    """
    # NOTE(review): statement nesting was reconstructed from a flattened
    # source; everything after the encoder-type check is kept inside it
    # because it depends on ``_encoder_cells`` -- confirm against upstream.
    hparams = self._hparams
    with tf.variable_scope("Encoder") as scope:
        encoder_inputs = self._maybe_add_dense_layers()

        if hparams.encoder_type == 'unidirectional':
            self._encoder_cells = build_rnn_layers(
                cell_type=hparams.cell_type,
                num_units_per_layer=self._num_units_per_layer,
                use_dropout=hparams.use_dropout,
                dropout_probability=hparams.dropout_probability,
                mode=self._mode,
                as_list=True,
                dtype=hparams.dtype)

            mechanism, output_attention = create_attention_mechanism(
                attention_type=hparams.attention_type[0][0],
                num_units=self._num_units_per_layer[-1],
                memory=self._attended_memory,
                memory_sequence_length=self._attended_memory_length,
                mode=self._mode,
                dtype=hparams.dtype)

            # Replace the top encoder layer with its attention-wrapped form.
            self._encoder_cells[-1] = seq2seq.AttentionWrapper(
                cell=self._encoder_cells[-1],
                attention_mechanism=mechanism,
                attention_layer_size=hparams.decoder_units_per_layer[-1],
                alignment_history=hparams.write_attention_alignment,
                output_attention=output_attention,
            )

            # batch_size holds one entry for training and one for the rest.
            batch_index = 0 if self._mode == 'train' else 1
            rnn_result = tf.nn.dynamic_rnn(
                cell=MultiRNNCell(self._encoder_cells),
                inputs=encoder_inputs,
                sequence_length=self._inputs_len,
                parallel_iterations=hparams.batch_size[batch_index],
                swap_memory=False,
                dtype=hparams.dtype,
                scope=scope,
            )
            self._encoder_outputs, self._encoder_final_state = rnn_result

            if hparams.write_attention_alignment is True:
                self.attention_summary = (
                    self._create_attention_alignments_summary(
                        self._encoder_final_state[-1]))
def _build_decoder_greedy(self):
    """Build the greedy inference decoder and store its outputs.

    Sets ``_helper_greedy``, ``_decoder_inference``, ``inference_outputs``
    and ``inference_predicted_ids`` on ``self``, plus ``attention_summary``
    when ``write_attention_alignment`` is enabled.
    """
    hparams = self._hparams
    batch_size, _ = tf.unstack(tf.shape(self._labels))

    self._helper_greedy = seq2seq.GreedyEmbeddingHelper(
        embedding=self._embedding_matrix,
        start_tokens=tf.tile([self._GO_ID], [batch_size]),
        end_token=self._EOS_ID)

    if hparams.enable_attention is True:
        attention_mechanisms, layer_sizes = (
            self._create_attention_mechanisms())

        cells = seq2seq.AttentionWrapper(
            cell=self._decoder_cells,
            attention_mechanism=attention_mechanisms,
            attention_layer_size=layer_sizes,
            initial_cell_state=self._decoder_initial_state,
            alignment_history=hparams.write_attention_alignment,
            output_attention=self._output_attention)

        # Zero attention state with the decoder initial state cloned in.
        zero_state = cells.zero_state(
            dtype=hparams.dtype, batch_size=batch_size)
        initial_state = zero_state.clone(
            cell_state=self._decoder_initial_state)
    else:
        cells = self._decoder_cells
        initial_state = self._decoder_initial_state

    self._decoder_inference = seq2seq.BasicDecoder(
        cell=cells,
        helper=self._helper_greedy,
        initial_state=initial_state,
        output_layer=self._dense_layer)

    outputs, states, lengths = seq2seq.dynamic_decode(
        self._decoder_inference,
        impute_finished=True,
        swap_memory=False,
        maximum_iterations=hparams.max_label_length)

    self.inference_outputs = outputs.rnn_output
    self.inference_predicted_ids = outputs.sample_id

    if hparams.write_attention_alignment is True:
        self.attention_summary = self._create_attention_alignments_summary(
            states)
def _decoder_cell(self):
    """Build the attentive GRU decoder cell and its zero state.

    The cell attends over ``self._targets_encoder_outputs`` and is given
    ``self._targets_encoder_state`` as its initial cell state.

    Returns:
        Tuple ``(attentive_cell, zero_state)``.
    """
    units = 2 * self.CELL_SIZE
    batch_size, _ = tf.unstack(tf.shape(self._targets))

    bahdanau = seq2seq.BahdanauAttention(
        num_units=units,
        memory=self._targets_encoder_outputs,
        memory_sequence_length=self._targets_length)

    gru = rnn.GRUCell(units, activation=tf.nn.tanh)
    attentive_cell = seq2seq.AttentionWrapper(
        cell=gru,
        attention_mechanism=bahdanau,
        attention_layer_size=units,
        initial_cell_state=self._targets_encoder_state)

    zero_state = attentive_cell.zero_state(batch_size, tf.float32)
    return (attentive_cell, zero_state)
def _build_decoder_cell(self, encoder_outputs, encoder_state,
                        source_sequence_length):
    """Build the Luong-attention decoder cell and its initial state.

    Args:
        encoder_outputs: Encoder outputs used as the attention memory. They
            are transposed with ``[1, 0, 2]`` when ``hparams.time_major`` is
            set and used unchanged otherwise.
        encoder_state: Encoder final state; cloned into the initial state
            when ``hparams.pass_hidden_state`` is set.
        source_sequence_length: Lengths used to mask the attention memory.

    Returns:
        Tuple ``(cell, initial_state)``.
    """
    hparams = self.hparams
    beam_width = hparams.beam_width

    if hparams.time_major:
        memory = tf.transpose(encoder_outputs, [1, 0, 2])
    else:
        # Without this branch ``memory`` is unbound for batch-major inputs.
        memory = encoder_outputs

    if self.mode == PREDICT and beam_width > 0:
        # Beam search expands every batch entry into beam_width rows.
        memory = seq2seq.tile_batch(memory, beam_width)
        source_sequence_length = seq2seq.tile_batch(
            source_sequence_length, beam_width)
        encoder_state = seq2seq.tile_batch(encoder_state, beam_width)
        batch_size = self.batch_size * beam_width
    else:
        batch_size = self.batch_size

    # Use Attention Mechanism
    attention_machanism = seq2seq.LuongAttention(
        num_units=hparams.num_units,
        memory=memory,
        memory_sequence_length=source_sequence_length)

    cell = model_helper.build_rnn_cell(
        hparams.unit_type,
        hparams.num_units,
        hparams.num_layers,
        hparams.dropout,
    )

    # Alignment history is only recorded for prediction without beam search.
    alignment_history = (self.mode == PREDICT and beam_width == 0)
    cell = seq2seq.AttentionWrapper(
        cell,
        attention_machanism,
        attention_layer_size=hparams.num_units,
        alignment_history=alignment_history,
        name='attention')

    initial_state = cell.zero_state(batch_size, tf.float32)
    if hparams.pass_hidden_state:
        initial_state = initial_state.clone(cell_state=encoder_state)

    return cell, initial_state
def _decoder(self, keep_prob, encoder_output, encoder_state, batch_size,
             scope, helper, reuse=None):
    """Build and run the attention decoder inside ``scope``.

    Args:
        keep_prob: Forwarded to ``self._cell`` for every stacked layer.
        encoder_output: Encoder outputs used as the attention memory.
        encoder_state: Encoder final state used as the initial cell state.
        batch_size: Batch size for the decoder zero state.
        scope: Variable scope (or scope name) the decoder is built in.
        helper: Decoder helper passed to ``BasicDecoder``.
        reuse: Variable reuse flag for the scope and projection wrapper.

    Returns:
        The first element returned by ``seq2seq.dynamic_decode``.
    """
    with tf.variable_scope(scope, reuse=reuse):
        stacked_cell = rnn.MultiRNNCell(
            [self._cell(keep_prob) for _ in range(self.lstm_dims)])

        # The whole encoder output serves as (unmasked) attention memory.
        bahdanau = seq2seq.BahdanauAttention(self.hidden_size, encoder_output)

        attentive_cell = seq2seq.AttentionWrapper(
            stacked_cell,
            bahdanau,
            attention_layer_size=self.hidden_size // 2)
        projected_cell = rnn.OutputProjectionWrapper(
            attentive_cell,
            self.hidden_size,
            reuse=reuse,
            activation=tf.nn.leaky_relu)

        zero_state = projected_cell.zero_state(batch_size, tf.float32)
        decoder_initial_state = zero_state.clone(cell_state=encoder_state)

        vocab_projection = tf.layers.Dense(
            self.num_words,
            kernel_initializer=tf.contrib.layers.xavier_initializer(),
            activation=tf.nn.leaky_relu)

        decoder = seq2seq.BasicDecoder(
            projected_cell,
            helper,
            decoder_initial_state,
            output_layer=vocab_projection)

        output, _, _ = seq2seq.dynamic_decode(
            decoder,
            maximum_iterations=self.max_sentence_length,
            impute_finished=True)
        return output
def decoding_layer(dec_embed_input, embeddings, enc_output, enc_state, vocab_size, text_length,
                   summary_length, max_summary_length, rnn_size, vocab_to_int, keep_prob,
                   batch_size, num_layers):
    """Create the attention decoder cell and build training and inference logits.

    Args:
        dec_embed_input: Embedded decoder inputs, passed to the training decoder.
        embeddings: Embedding matrix, passed to the inference decoder.
        enc_output: Encoder outputs used as the attention memory.
        enc_state: Encoder final state; ``enc_state[0]`` seeds the decoder cell.
        vocab_size: Number of units of the output projection layer.
        text_length: Source lengths, used as the attention memory lengths.
        summary_length: Target lengths, passed to the training decoder.
        max_summary_length: Upper bound on the decoded length.
        rnn_size: LSTM size; also the attention depth and attention layer size.
        vocab_to_int: Vocabulary mapping; must contain '<GO>' and '<EOS>'.
        keep_prob: Input keep probability for the dropout wrapper.
        batch_size: Batch size used for the initial state and inference.
        num_layers: Number of loop iterations that build a decoder cell
            (must be >= 1).

    Returns:
        A ``(training_logits, inference_logits)`` tuple.

    Raises:
        ValueError: If ``num_layers`` is smaller than 1.
    """
    if num_layers < 1:
        # Previously this surfaced later as an unbound `dec_cell`.
        raise ValueError('num_layers must be >= 1, got {}'.format(num_layers))

    # NOTE(review): every iteration rebinds `dec_cell`, so only the cell built
    # in the last iteration is used below; the cells are not stacked. Stacking
    # them would also need one initial state per layer, so it is left as is.
    for layer in range(num_layers):
        with tf.variable_scope('decoder_{}'.format(layer)):
            lstm = tf.nn.rnn_cell.LSTMCell(
                rnn_size,
                initializer=tf.random_uniform_initializer(-0.1, 0.1, seed=2))
            dec_cell = tf.nn.rnn_cell.DropoutWrapper(lstm, input_keep_prob=keep_prob)

    # Fully connected layer projecting decoder outputs onto the vocabulary.
    output_layer = Dense(
        vocab_size,
        kernel_initializer=tf.truncated_normal_initializer(mean=0.0, stddev=0.1))

    # Introduce the attention mechanism.
    attn_mech = seq.BahdanauAttention(
        rnn_size, enc_output, text_length, normalize=False, name='BahdanauAttention')
    dec_cell = seq.AttentionWrapper(
        cell=dec_cell, attention_mechanism=attn_mech, attention_layer_size=rnn_size)

    # Let the wrapper build a complete zero state and only replace its cell
    # state; constructing AttentionWrapperState by hand with two fields does
    # not match the state structure the wrapper expects.
    initial_state = dec_cell.zero_state(batch_size, tf.float32).clone(
        cell_state=enc_state[0])

    with tf.variable_scope("decode"):
        training_logits = training_decoding_layer(
            dec_embed_input, summary_length, dec_cell, initial_state,
            output_layer, vocab_size, max_summary_length)

    # Same scope with reuse=True so inference shares the training variables.
    with tf.variable_scope("decode", reuse=True):
        inference_logits = inference_decoding_layer(
            embeddings, vocab_to_int['<GO>'], vocab_to_int['<EOS>'],
            dec_cell, initial_state, output_layer, max_summary_length, batch_size)

    return training_logits, inference_logits
def decoding_layer_train(self, num_units, max_time, batch_size, char2numY, data_output_embed,
                         encoder_output, last_state, bidirectional):
    """Build the training-time attention decoder and return its decoded outputs.

    Args:
        num_units: Base LSTM size; doubled for the decoder cell when
            ``bidirectional`` is truthy. Also the attention depth and the
            attention layer size.
        max_time: Length assigned to every sequence, and the decode step limit.
        batch_size: Python int; used to repeat ``max_time`` and to build the
            zero initial state.
        char2numY: Target vocabulary mapping; the output layer has
            ``len(char2numY) - 2`` units.
        data_output_embed: Embedded decoder inputs (batch-major, since the
            helper is created with ``time_major=False``).
        encoder_output: Memory for the Bahdanau attention.
        last_state: Encoder state used as the decoder's initial cell state.
        bidirectional: Whether the decoder cell is built with doubled size.

    Returns:
        The first element of the ``seq2seq.dynamic_decode`` result.
    """
    cell_units = 2 * num_units if bidirectional else num_units
    lstm = rnn.LSTMCell(cell_units)

    # Every sequence in the batch is treated as exactly max_time steps long.
    full_lengths = [max_time] * batch_size

    helper = seq2seq.TrainingHelper(
        inputs=data_output_embed,
        sequence_length=full_lengths,
        time_major=False)

    bahdanau = seq2seq.BahdanauAttention(
        num_units=num_units,
        memory=encoder_output,
        memory_sequence_length=full_lengths)
    attended_lstm = seq2seq.AttentionWrapper(
        cell=lstm,
        attention_mechanism=bahdanau,
        attention_layer_size=num_units)

    # Zero attention state, with the encoder's last state as the cell state.
    zero_state = attended_lstm.zero_state(batch_size=batch_size, dtype=tf.float32)
    start_state = zero_state.clone(cell_state=last_state)

    projection = tf.layers.Dense(
        len(char2numY) - 2,
        kernel_initializer=tf.truncated_normal_initializer(mean=0.0, stddev=0.1))

    decoder = seq2seq.BasicDecoder(
        cell=attended_lstm,
        helper=helper,
        initial_state=start_state,
        output_layer=projection)

    decoded, _, _ = seq2seq.dynamic_decode(
        decoder=decoder,
        impute_finished=True,
        maximum_iterations=max_time)
    return decoded
def wrap_att(self, dec_cell, lstm_size, enc_output, lengths, alignment_history=False):
    """Return ``dec_cell`` wrapped with global Luong attention, as in the paper.

    Args:
        dec_cell: Decoder RNN cell to wrap.
        lstm_size: Number of units of the attention mechanism.
        enc_output: Encoder output used as the attention memory.
        lengths: Memory sequence lengths for the attention mechanism.
        alignment_history: Forwarded to the wrapper unchanged.

    Returns:
        An ``AttentionWrapper`` around ``dec_cell``, created with
        ``attention_layer_size=None`` and ``output_attention=False``.
    """
    luong = s2s.LuongAttention(
        name='LuongAttention',
        num_units=lstm_size,
        memory=enc_output,
        memory_sequence_length=lengths)

    # Wrap the decoder cell in a seq2seq AttentionWrapper.
    wrapped_cell = s2s.AttentionWrapper(
        cell=dec_cell,
        attention_mechanism=luong,
        alignment_history=alignment_history,
        output_attention=False,
        attention_layer_size=None)
    return wrapped_cell
def create_decoder_cell(agenda, base_sent_embeds, insert_word_embeds, delete_word_embeds,
                        base_length, iw_length, dw_length, attn_dim, hidden_dim, num_layer,
                        enable_alignment_history=False, enable_dropout=False,
                        dropout_keep=0.1, no_insert_delete_attn=False):
    """Assemble the multi-layer decoder cell whose bottom layer carries attention.

    The bottom LSTM is wrapped with Bahdanau attention over the base sentence
    and, unless ``no_insert_delete_attn`` is set, over the insert and delete
    word embeddings as well. ``num_layer - 1`` plain LSTM layers go on top.

    Args:
        agenda: Passed to ``set_agenda`` on the resulting cell.
        base_sent_embeds: Memory for the 'src_attn' mechanism.
        insert_word_embeds: Memory for the 'insert_attn' mechanism.
        delete_word_embeds: Memory for the 'delete_attn' mechanism.
        base_length: Memory lengths for ``base_sent_embeds``.
        iw_length: Memory lengths for ``insert_word_embeds``.
        dw_length: Memory lengths for ``delete_word_embeds``.
        attn_dim: Number of units of each attention mechanism.
        hidden_dim: Size of every LSTM layer.
        num_layer: Total number of layers, counting the attention layer.
        enable_alignment_history: Forwarded to the attention wrapper.
        enable_dropout: Whether the upper layers get output dropout.
        dropout_keep: Output keep probability; only applied when below 1.
        no_insert_delete_attn: If set, attend over the base sentence only.

    Returns:
        An ``AttentionAugmentRNNCell`` over all layers with the agenda set.
    """
    mechanisms = [
        seq2seq.BahdanauAttention(attn_dim, base_sent_embeds, base_length, name='src_attn'),
    ]
    if not no_insert_delete_attn:
        mechanisms.extend([
            seq2seq.BahdanauAttention(
                attn_dim, insert_word_embeds, iw_length, name='insert_attn'),
            seq2seq.BahdanauAttention(
                attn_dim, delete_word_embeds, dw_length, name='delete_attn'),
        ])

    # Sanity check: one mechanism alone, or all three together.
    assert len(mechanisms) == (1 if no_insert_delete_attn else 3)

    attention_layer = seq2seq.AttentionWrapper(
        tf_rnn.LSTMCell(hidden_dim, name='bottom_cell'),
        tuple(mechanisms),
        output_attention=False,
        alignment_history=enable_alignment_history,
        name='att_bottom_cell'
    )
    layers = [attention_layer]

    # The attention layer is layer 0; the remaining layers are numbered from 1.
    for layer_no in range(1, num_layer):
        upper = tf_rnn.LSTMCell(hidden_dim, name='layer_%s' % layer_no)
        if enable_dropout and dropout_keep < 1.:
            upper = tf_rnn.DropoutWrapper(upper, output_keep_prob=dropout_keep)
        layers.append(upper)

    decoder_cell = AttentionAugmentRNNCell(layers)
    decoder_cell.set_agenda(agenda)
    return decoder_cell
def build_model(self):
    """Wire the encoder to an attention-wrapped decoder cell.

    Reads ``self.encoder``, ``self.inputs``, ``self.attention_size`` and
    ``self.decoder``. Stores the attention mechanism on ``self.bahdanau`` and
    the wrapped decoder cell on ``self.decoder_cell``. Returns nothing.
    """
    with tf.variable_scope('encoder'):
        # Split the inputs along axis 1 into the list the static RNN expects.
        steps = tf.unstack(self.inputs, axis=1, name='TimeMajorInputs')
        # The same encoder cell object serves both directions.
        annotations, _, _ = tf.nn.static_bidirectional_rnn(
            cell_fw=self.encoder,
            cell_bw=self.encoder,
            inputs=steps,
            dtype=self.inputs.dtype)

    with tf.variable_scope('decoder'):
        with tf.name_scope('attention'):
            # Re-join the per-step outputs along axis 1 to form the memory.
            stacked = tf.stack(annotations, axis=1, name='BatchMajorAnnotations')
            self.bahdanau = seq2seq.BahdanauAttention(self.attention_size, memory=stacked)

        self.decoder_cell = seq2seq.AttentionWrapper(
            self.decoder, self.bahdanau, output_attention=False)