def inference_decode_layer(self, start_token, dec_cell, end_token, output_layer):
    start_tokens = tf.tile(tf.constant([start_token], dtype=tf.int32),
                           [self.batch_size], name='start_token')

    # Replicate the encoder memory, state and lengths once per beam.
    tiled_enc_output = seq2seq.tile_batch(self.enc_output, multiplier=self.Beam_width)
    tiled_enc_state = seq2seq.tile_batch(self.enc_state, multiplier=self.Beam_width)
    tiled_source_len = seq2seq.tile_batch(self.source_len, multiplier=self.Beam_width)

    atten_mech = seq2seq.BahdanauAttention(self.hidden_dim * 2, tiled_enc_output,
                                           tiled_source_len, normalize=True)
    decoder_att = seq2seq.AttentionWrapper(dec_cell, atten_mech, self.hidden_dim * 2)
    initial_state = decoder_att.zero_state(
        self.batch_size * self.Beam_width, tf.float32).clone(cell_state=tiled_enc_state)

    decoder = seq2seq.BeamSearchDecoder(decoder_att, self.embeddings, start_tokens,
                                        end_token, initial_state,
                                        beam_width=self.Beam_width,
                                        output_layer=output_layer)
    infer_logits, _, _ = seq2seq.dynamic_decode(decoder,
                                                output_time_major=False,
                                                impute_finished=False,
                                                maximum_iterations=self.max_target_len)
    return infer_logits
def build_decoder_cell(self):
    encoder_outputs = self.encoder_outputs
    encoder_last_state = self.encoder_last_state
    encoder_inputs_length = self.encoder_inputs_length

    if self.use_beamsearch_decode:
        print("use beamsearch decoding..")
        encoder_outputs = seq2seq.tile_batch(
            self.encoder_outputs, multiplier=self.beam_width)
        encoder_last_state = nest.map_structure(
            lambda s: seq2seq.tile_batch(s, self.beam_width),
            self.encoder_last_state)
        encoder_inputs_length = seq2seq.tile_batch(
            self.encoder_inputs_length, multiplier=self.beam_width)

    # Building attention mechanism: Default Bahdanau
    # 'Bahdanau' style attention: https://arxiv.org/abs/1409.0473
    self.attention_mechanism = attention_wrapper.BahdanauAttention(
        num_units=self.hidden_units,
        memory=encoder_outputs,
        memory_sequence_length=encoder_inputs_length)

    # 'Luong' style attention: https://arxiv.org/abs/1508.04025
    if self.attention_type.lower() == 'luong':
        self.attention_mechanism = attention_wrapper.LuongAttention(
            num_units=self.hidden_units,
            memory=encoder_outputs,
            memory_sequence_length=encoder_inputs_length)

    # Building decoder_cell
    self.decoder_cell_list = [
        self.build_single_cell() for i in range(self.depth)]

    def attn_decoder_input_fn(inputs, attention):
        if not self.attn_input_feeding:
            return inputs
        # Essential when use_residual=True
        _input_layer = Dense(self.hidden_units, dtype=self.dtype,
                             name='attn_input_feeding')
        return _input_layer(array_ops.concat([inputs, attention], -1))

    # AttentionWrapper wraps RNNCell with the attention_mechanism
    # Note: We implement Attention mechanism only on the top decoder layer
    self.decoder_cell_list[-1] = attention_wrapper.AttentionWrapper(
        cell=self.decoder_cell_list[-1],
        attention_mechanism=self.attention_mechanism,
        attention_layer_size=self.hidden_units,
        cell_input_fn=attn_decoder_input_fn,
        initial_cell_state=encoder_last_state[-1],
        alignment_history=False,
        name='Attention_Wrapper')

    batch_size = self.batch_size if not self.use_beamsearch_decode \
        else self.batch_size * self.beam_width
    initial_state = [state for state in encoder_last_state]
    initial_state[-1] = self.decoder_cell_list[-1].zero_state(
        batch_size=batch_size, dtype=self.dtype)
    decoder_initial_state = tuple(initial_state)

    return MultiRNNCell(self.decoder_cell_list), decoder_initial_state
def _build_decoder_cell(self, enc_outputs, enc_state):
    beam_size = self.config.beam_size
    context_length = self.source_length
    memory = enc_outputs

    if self.mode == ModelMode.infer and beam_size > 0:
        enc_state = tc_seq2seq.tile_batch(enc_state, multiplier=beam_size)
        memory = tc_seq2seq.tile_batch(memory, multiplier=beam_size)
        context_length = tc_seq2seq.tile_batch(context_length, multiplier=beam_size)
        batch_size = self.batch_size * beam_size
    else:
        batch_size = self.batch_size

    dec_cell = get_rnn_cell(self.config.unit_type,
                            hidden_size=self.config.dec_hidden_size,
                            num_layers=self.config.num_layers,
                            dropout_keep_prob=self.dropout_keep_prob)
    return dec_cell, enc_state
def _create_attention_mechanisms(self, beam_search=False):
    r"""
    Creates a list of attention mechanisms (e.g. seq2seq.BahdanauAttention)
    and a list of ints holding the attention projection layer sizes.

    Args:
        beam_search: `bool`, whether the beam-search decoding algorithm
            is used or not
    """
    mechanisms = []
    layer_sizes = []

    if beam_search is True:
        encoder_memory = seq2seq.tile_batch(
            self._encoder_memory, multiplier=self._hparams.beam_width)
        encoder_features_len = seq2seq.tile_batch(
            self._encoder_features_len, multiplier=self._hparams.beam_width)
    else:
        encoder_memory = self._encoder_memory
        encoder_features_len = self._encoder_features_len

    for attention_type in self._hparams.attention_type[0]:
        attention = self._create_attention_mechanism(
            num_units=self._hparams.decoder_units_per_layer[-1],
            memory=encoder_memory,
            memory_sequence_length=encoder_features_len,
            attention_type=attention_type)
        mechanisms.append(attention)
        layer_sizes.append(self._hparams.decoder_units_per_layer[-1])

    return mechanisms, layer_sizes
def _build_decoder_cell(self, enc_outputs, enc_state):
    beam_size = self.config.beam_size  # e.g. 5; beam search width
    context_length = self.source_length
    memory = enc_outputs

    # Beam search is only performed at the inference stage.
    if self.mode == ModelMode.infer and beam_size > 0:
        # Helpers from beam_search_decoder
        enc_state = tc_seq2seq.tile_batch(enc_state, multiplier=beam_size)
        memory = tc_seq2seq.tile_batch(memory, multiplier=beam_size)
        context_length = tc_seq2seq.tile_batch(context_length, multiplier=beam_size)
    else:
        batch_size = self.batch_size

    dec_cell = get_rnn_cell(
        self.config.unit_type,  # e.g. 'lstm'
        hidden_size=self.config.dec_hidden_size,  # e.g. 300
        num_layers=1,
        dropout_keep_prob=self.dropout_keep_prob)
    return dec_cell, enc_state
def build_decoder_cell(self):
    encoder_inputs_length = self.encoder_inputs_length
    if self.beam_search:
        print("use beamsearch decoding..")
        self.encoder_outputs = tile_batch(self.encoder_outputs,
                                          multiplier=self.beam_size)
        self.encoder_state = nest.map_structure(
            lambda s: tile_batch(s, self.beam_size), self.encoder_state)
        encoder_inputs_length = tile_batch(encoder_inputs_length,
                                           multiplier=self.beam_size)

    # Define the attention mechanism to use.
    attention_mechanism = BahdanauAttention(
        num_units=self.rnn_size,
        memory=self.encoder_outputs,
        memory_sequence_length=encoder_inputs_length)

    # Define the RNNCell for the decoder stage, then wrap it with attention.
    decoder_cell = self.create_rnn_cell()
    decoder_cell = AttentionWrapper(
        cell=decoder_cell,
        attention_mechanism=attention_mechanism,
        attention_layer_size=self.rnn_size,
        name='Attention_Wrapper')

    batch_size = self.batch_size if not self.beam_search else self.batch_size * self.beam_size
    decoder_initial_state = decoder_cell.zero_state(
        batch_size=batch_size, dtype=tf.float32).clone(cell_state=self.encoder_state)
    return decoder_cell, decoder_initial_state
def getBeamSearchDecoderCell(self, encoder_outputs, encoder_final_states):
    # layer_num, beam_size and num_units are expected to be defined in the
    # enclosing scope.
    basic_cells = [self.get_basicLSTMCell() for i in range(layer_num)]
    basic_cell = tf.nn.rnn_cell.MultiRNNCell(basic_cells)

    tiled_encoder_outputs = seq2seq.tile_batch(encoder_outputs,
                                               multiplier=beam_size)
    tiled_encoder_final_states = [
        seq2seq.tile_batch(state, multiplier=beam_size)
        for state in encoder_final_states
    ]
    tiled_sequence_length = seq2seq.tile_batch(self.enc_len,
                                               multiplier=beam_size)
    initial_state = tuple(tiled_encoder_final_states)

    # attention
    attention_mechanism = seq2seq.BahdanauAttention(
        num_units=num_units,
        memory=tiled_encoder_outputs,
        memory_sequence_length=tiled_sequence_length)
    att_cell = seq2seq.AttentionWrapper(
        basic_cell,
        attention_mechanism=attention_mechanism,
        attention_layer_size=num_units,
        alignment_history=False,
        cell_input_fn=None,
        initial_cell_state=initial_state)

    initial_state = att_cell.zero_state(
        batch_size=tf.shape(self.enc_in)[0] * beam_size, dtype=tf.float32)
    # att_state.clone(cell_state=encoder_final_state)
    return att_cell, initial_state
def _get_beam_search_cell(self, beam_width):
    """Returns the RNN cell for beam search decoding.
    """
    with tf.variable_scope(self.variable_scope, reuse=True):
        attn_kwargs = copy.copy(self._attn_kwargs)

        memory = attn_kwargs['memory']
        attn_kwargs['memory'] = tile_batch(memory, multiplier=beam_width)

        memory_seq_length = attn_kwargs['memory_sequence_length']
        if memory_seq_length is not None:
            attn_kwargs['memory_sequence_length'] = tile_batch(
                memory_seq_length, beam_width)

        attn_modules = ['tensorflow.contrib.seq2seq', 'texar.tf.custom']
        bs_attention_mechanism = utils.check_or_get_instance(
            self._hparams.attention.type, attn_kwargs, attn_modules,
            classtype=tf.contrib.seq2seq.AttentionMechanism)

        bs_attn_cell = AttentionWrapper(self._cell._cell,
                                        bs_attention_mechanism,
                                        cell_input_fn=self._cell_input_fn,
                                        **self._attn_cell_kwargs)
        self._beam_search_cell = bs_attn_cell
        return bs_attn_cell
def beam_eval_decoder(agenda, embeddings, extended_base_words, oov,
                      start_token_id, stop_token_id, base_sent_hiddens,
                      base_length, vocab_size, attn_dim, hidden_dim,
                      num_layer, max_sentence_length, beam_width,
                      swap_memory, enable_dropout=False, dropout_keep=1.,
                      no_insert_delete_attn=False):
    with tf.variable_scope(OPS_NAME, 'decoder', reuse=True):
        true_batch_size = tf.shape(base_sent_hiddens)[0]

        tiled_agenda = seq2seq.tile_batch(agenda, beam_width)
        tiled_extended_base_words = seq2seq.tile_batch(extended_base_words,
                                                       beam_width)
        tiled_oov = seq2seq.tile_batch(oov, beam_width)
        tiled_base_sent = seq2seq.tile_batch(base_sent_hiddens, beam_width)
        tiled_base_lengths = seq2seq.tile_batch(base_length, beam_width)

        start_token_id = tf.cast(start_token_id, tf.int32)
        stop_token_id = tf.cast(stop_token_id, tf.int32)

        cell, zero_states = create_decoder_cell(
            tiled_agenda, tiled_extended_base_words, tiled_oov,
            tiled_base_sent, tiled_base_lengths,
            vocab_size, attn_dim, hidden_dim, num_layer,
            enable_dropout=enable_dropout, dropout_keep=dropout_keep,
            no_insert_delete_attn=no_insert_delete_attn,
            beam_width=beam_width)

        decoder = seq2seq.BeamSearchDecoder(
            cell, create_embedding_fn(vocab_size),
            tf.fill([true_batch_size], start_token_id), stop_token_id,
            zero_states, beam_width=beam_width, length_penalty_weight=0.0)

        return seq2seq.dynamic_decode(decoder,
                                      maximum_iterations=max_sentence_length,
                                      swap_memory=swap_memory)
def _build_decoder_cell(self):
    # no beam
    encoder_outputs = self.encoder_outputs
    encoder_last_state = self.encoder_last_state
    encoder_inputs_length = self.encoder_inputs_length

    def attn_decoder_input_fn(inputs, attention):
        if not self.attn_input_feeding:
            return inputs
        _input_layer = Dense(self.hidden_units, dtype=self.dtype,
                             name="attn_input_feeding")
        return _input_layer(array_ops.concat([inputs, attention], -1))

    # attention mechanism 'luong'
    with tf.variable_scope('shared_attention_mechanism'):
        self.attention_mechanism = attention_wrapper.LuongAttention(
            num_units=self.hidden_units,
            memory=encoder_outputs,
            memory_sequence_length=encoder_inputs_length)

    # build decoder cell
    self.init_decoder_cell_list = [
        self._build_single_cell() for i in range(self.depth)]
    self.decoder_cell_list = self.init_decoder_cell_list[:-1] + [
        attention_wrapper.AttentionWrapper(
            cell=self.init_decoder_cell_list[-1],
            attention_mechanism=self.attention_mechanism,
            attention_layer_size=self.hidden_units,
            cell_input_fn=attn_decoder_input_fn,
            initial_cell_state=encoder_last_state[-1],
            alignment_history=False)]

    batch_size = self.batch_size
    initial_state = [state for state in encoder_last_state]
    initial_state[-1] = self.decoder_cell_list[-1].zero_state(
        batch_size=batch_size, dtype=self.dtype)
    decoder_initial_state = tuple(initial_state)

    # beam
    beam_encoder_outputs = seq2seq.tile_batch(
        self.encoder_outputs, multiplier=self.beam_width)
    beam_encoder_last_state = nest.map_structure(
        lambda s: seq2seq.tile_batch(s, self.beam_width),
        self.encoder_last_state)
    beam_encoder_inputs_length = seq2seq.tile_batch(
        self.encoder_inputs_length, multiplier=self.beam_width)

    with tf.variable_scope('shared_attention_mechanism', reuse=True):
        self.beam_attention_mechanism = attention_wrapper.LuongAttention(
            num_units=self.hidden_units,
            memory=beam_encoder_outputs,
            memory_sequence_length=beam_encoder_inputs_length)

    self.beam_decoder_cell_list = self.init_decoder_cell_list[:-1] + [
        attention_wrapper.AttentionWrapper(
            cell=self.init_decoder_cell_list[-1],
            attention_mechanism=self.beam_attention_mechanism,
            attention_layer_size=self.hidden_units,
            cell_input_fn=attn_decoder_input_fn,
            initial_cell_state=beam_encoder_last_state[-1],
            alignment_history=False)]

    beam_batch_size = self.batch_size * self.beam_width
    beam_initial_state = [state for state in beam_encoder_last_state]
    beam_initial_state[-1] = self.beam_decoder_cell_list[-1].zero_state(
        batch_size=beam_batch_size, dtype=self.dtype)
    beam_decoder_initial_state = tuple(beam_initial_state)

    return (MultiRNNCell(self.decoder_cell_list), decoder_initial_state,
            MultiRNNCell(self.beam_decoder_cell_list), beam_decoder_initial_state)
def decoder(x, decoder_inputs, keep_prob, sequence_length, memory,
            memory_length, first_attention):
    with tf.variable_scope("Decoder") as scope:
        label_embeddings = tf.get_variable(name="embeddings",
                                           shape=[n_classes, embedding_size],
                                           dtype=tf.float32)
        train_inputs_embedded = tf.nn.embedding_lookup(label_embeddings,
                                                       decoder_inputs)
        lstm = rnn.LayerNormBasicLSTMCell(n_hidden, dropout_keep_prob=keep_prob)
        output_l = layers_core.Dense(n_classes, use_bias=True)
        encoder_state = rnn.LSTMStateTuple(x, x)

        # Training branch
        attention_mechanism = BahdanauAttention(
            embedding_size, memory=memory, memory_sequence_length=memory_length)
        cell = AttentionWrapper(lstm, attention_mechanism, output_attention=False)
        cell_state = cell.zero_state(dtype=tf.float32,
                                     batch_size=train_batch_size)
        cell_state = cell_state.clone(cell_state=encoder_state,
                                      attention=first_attention)
        train_helper = TrainingHelper(train_inputs_embedded, sequence_length)
        train_decoder = BasicDecoder(cell, train_helper, cell_state,
                                     output_layer=output_l)
        decoder_outputs_train, decoder_state_train, decoder_seq_train = \
            dynamic_decode(train_decoder, impute_finished=True)

        # Inference branch: tile the memory and initial attention per beam
        tiled_inputs = tile_batch(memory, multiplier=beam_width)
        tiled_sequence_length = tile_batch(memory_length, multiplier=beam_width)
        tiled_first_attention = tile_batch(first_attention, multiplier=beam_width)
        attention_mechanism = BahdanauAttention(
            embedding_size, memory=tiled_inputs,
            memory_sequence_length=tiled_sequence_length)
        x2 = tile_batch(x, beam_width)
        encoder_state2 = rnn.LSTMStateTuple(x2, x2)
        cell = AttentionWrapper(lstm, attention_mechanism, output_attention=False)
        cell_state = cell.zero_state(dtype=tf.float32,
                                     batch_size=test_batch_size * beam_width)
        cell_state = cell_state.clone(cell_state=encoder_state2,
                                      attention=tiled_first_attention)
        infer_decoder = BeamSearchDecoder(cell, embedding=label_embeddings,
                                          start_tokens=[GO] * test_len,
                                          end_token=EOS,
                                          initial_state=cell_state,
                                          beam_width=beam_width,
                                          output_layer=output_l)
        decoder_outputs_infer, decoder_state_infer, decoder_seq_infer = \
            dynamic_decode(infer_decoder, maximum_iterations=4)

        return decoder_outputs_train, decoder_outputs_infer, decoder_state_infer
def build_dec_cell(self, hidden_size):
    enc_outputs = self.enc_outputs
    enc_last_state = self.enc_last_state
    enc_inputs_length = self.enc_inp_len

    if self.use_beam_search:
        self.logger.info("using beam search decoding")
        enc_outputs = seq2seq.tile_batch(self.enc_outputs,
                                         multiplier=self.p.beam_width)
        enc_last_state = nest.map_structure(
            lambda s: seq2seq.tile_batch(s, self.p.beam_width),
            self.enc_last_state)
        enc_inputs_length = seq2seq.tile_batch(self.enc_inp_len,
                                               self.p.beam_width)

    if self.p.attention_type.lower() == 'luong':
        self.attention_mechanism = attention_wrapper.LuongAttention(
            num_units=hidden_size,
            memory=enc_outputs,
            memory_sequence_length=enc_inputs_length)
    else:
        self.attention_mechanism = attention_wrapper.BahdanauAttention(
            num_units=hidden_size,
            memory=enc_outputs,
            memory_sequence_length=enc_inputs_length)

    def attn_dec_input_fn(inputs, attention):
        if not self.p.attn_input_feeding:
            return inputs
        else:
            _input_layer = Dense(hidden_size, dtype=self.p.dtype,
                                 name='attn_input_feeding')
            return _input_layer(tf.concat([inputs, attention], -1))

    self.dec_cell_list = [
        self.build_single_cell(hidden_size) for _ in range(self.p.depth)
    ]

    if self.p.use_attn:
        self.dec_cell_list[-1] = attention_wrapper.AttentionWrapper(
            cell=self.dec_cell_list[-1],
            attention_mechanism=self.attention_mechanism,
            attention_layer_size=hidden_size,
            cell_input_fn=attn_dec_input_fn,
            initial_cell_state=enc_last_state[-1],
            alignment_history=False,
            name='attention_wrapper')

    batch_size = self.p.batch_size if not self.use_beam_search \
        else self.p.batch_size * self.p.beam_width
    initial_state = [state for state in enc_last_state]
    if self.p.use_attn:
        initial_state[-1] = self.dec_cell_list[-1].zero_state(
            batch_size=batch_size, dtype=self.p.dtype)
    dec_initial_state = tuple(initial_state)

    return MultiRNNCell(self.dec_cell_list), dec_initial_state
def setup_decoder_cell(self, config, keep_prob, use_beam_search, init_state,
                       attention_states, attention_lengths):
    batch_size = get_state_shape(init_state)[0]
    if use_beam_search:
        attention_states = tile_batch(attention_states,
                                      multiplier=self.beam_width)
        init_state = nest.map_structure(
            lambda s: tile_batch(s, self.beam_width), init_state)
        attention_lengths = tile_batch(attention_lengths,
                                       multiplier=self.beam_width)
        batch_size = batch_size * self.beam_width

    attention_size = shape(attention_states, -1)
    attention = getattr(tf.contrib.seq2seq, config.attention_type)(
        attention_size, attention_states,
        memory_sequence_length=attention_lengths)

    def cell_input_fn(inputs, attention):
        # define cell input function to keep input/output dimension same
        if not config.use_attention_input_feeding:
            return inputs
        attn_project = tf.layers.Dense(config.hidden_size, dtype=tf.float32,
                                       name='attn_input_feeding',
                                       activation=self.activation)
        return attn_project(tf.concat([inputs, attention], axis=-1))

    cells = _setup_decoder_cell(config, keep_prob)
    if config.top_attention:  # apply attention mechanism only on the top decoder layer
        cells[-1] = AttentionWrapper(cells[-1],
                                     attention_mechanism=attention,
                                     name="AttentionWrapper",
                                     attention_layer_size=config.hidden_size,
                                     alignment_history=use_beam_search,
                                     initial_cell_state=init_state[-1],
                                     cell_input_fn=cell_input_fn)
        init_state = [state for state in init_state]
        init_state[-1] = cells[-1].zero_state(batch_size=batch_size,
                                              dtype=tf.float32)
        init_state = tuple(init_state)
        cells = MultiRNNCell(cells)
    else:
        cells = MultiRNNCell(cells)
        cells = AttentionWrapper(cells,
                                 attention_mechanism=attention,
                                 name="AttentionWrapper",
                                 attention_layer_size=config.hidden_size,
                                 alignment_history=use_beam_search,
                                 initial_cell_state=init_state,
                                 cell_input_fn=cell_input_fn)
        init_state = cells.zero_state(batch_size=batch_size,
                                      dtype=tf.float32).clone(cell_state=init_state)
    return cells, init_state
def build_decoder_cell(self):
    encoder_inputs_length = self.encoder_inputs_length  # encoder input lengths
    if self.beam_search:  # whether to use beam search
        print("use beamsearch decoding..")
        # When beam_search is used, the encoder outputs must go through
        # tile_batch, which replicates its first argument `multiplier` times
        # along the batch axis -- here, beam_size times.
        self.encoder_outputs = tile_batch(self.encoder_outputs,
                                          multiplier=self.beam_size)
        # The lambda is an anonymous function of s.
        # nest.map_structure(func, structure) applies func to every element
        # of the structure and returns the results. Since an LSTM state has
        # the two components c and h, nest.map_structure() is needed here.
        self.encoder_state = nest.map_structure(
            lambda s: tile_batch(s, self.beam_size), self.encoder_state)
        encoder_inputs_length = tile_batch(encoder_inputs_length,
                                           multiplier=self.beam_size)

    # Define the attention mechanism to use.
    # We use Bahdanau attention; for the details of this mechanism see
    # Dzmitry Bahdanau, Kyunghyun Cho, Yoshua Bengio.
    # "Neural Machine Translation by Jointly Learning to Align and Translate."
    # ICLR 2015. https://arxiv.org/abs/1409.0473
    # A normalized variant also exists; to use it in TensorFlow, pass
    # normalize=True. For the details of the normalization see
    # Tim Salimans, Diederik P. Kingma.
    # "Weight Normalization: A Simple Reparameterization to Accelerate
    # Training of Deep Neural Networks." https://arxiv.org/abs/1602.07868
    attention_mechanism = BahdanauAttention(
        num_units=self.rnn_size,  # dimensionality of the hidden layer
        memory=self.encoder_outputs,  # usually the encoder outputs
        # mask for the memory; positions beyond each length are excluded
        # from attention
        memory_sequence_length=encoder_inputs_length)

    # Define the RNNCell for the decoder stage, then wrap it with attention.
    decoder_cell = self.create_rnn_cell()  # the decoder RNNCell
    decoder_cell = AttentionWrapper(  # AttentionWrapper adds attention to an RNN cell
        cell=decoder_cell,  # the RNN cell to wrap
        attention_mechanism=attention_mechanism,  # the AttentionMechanism instance
        attention_layer_size=self.rnn_size,  # TODO: is this the state size of the attention-wrapped RNN?
        name='Attention_Wrapper')  # name of the AttentionWrapper

    # With beam search, batch_size = self.batch_size * self.beam_size
    batch_size = self.batch_size if not self.beam_search else self.batch_size * self.beam_size

    # AttentionWrapper.zero_state() zero-initializes the wrapper state;
    # .clone() then copies the given state into the zero-initialized one.
    # Here the encoder's final hidden state initializes the decoder state.
    decoder_initial_state = decoder_cell.zero_state(
        batch_size=batch_size,
        dtype=tf.float32).clone(cell_state=self.encoder_state)
    return decoder_cell, decoder_initial_state
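# The tile_batch semantics described in the comments above can be checked in
# isolation. A minimal sketch (assuming the TF 1.x contrib APIs used
# throughout these snippets; the toy tensor is illustrative only):
import tensorflow as tf
from tensorflow.contrib.seq2seq import tile_batch

memory = tf.constant([[[1.0], [2.0]],
                      [[3.0], [4.0]]])    # shape [batch=2, time=2, depth=1]
tiled = tile_batch(memory, multiplier=3)  # shape [batch=6, time=2, depth=1]
# Each batch entry is repeated 3 times in a row (entries 0,0,0,1,1,1),
# exactly the layout BeamSearchDecoder expects for its memory and state.
with tf.Session() as sess:
    print(sess.run(tf.shape(tiled)))      # -> [6 2 1]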
def __graph__(self):
    # encoder
    encoder_outputs, encoder_state = self.encoder()

    # decoder
    with tf.variable_scope('decoder'):
        encoder_inputs_length = self.encoder_inputs_length
        if self.beam_search:
            # For beam search the encoder outputs must go through
            # tile_batch, which simply replicates them beam_size times.
            print("use beamsearch decoding..")
            encoder_outputs = tile_batch(encoder_outputs,
                                         multiplier=self.beam_size)
            encoder_state = nest.map_structure(
                lambda s: tf.contrib.seq2seq.tile_batch(s, self.beam_size),
                encoder_state)
            encoder_inputs_length = tile_batch(encoder_inputs_length,
                                               multiplier=self.beam_size)

        # Define the attention mechanism to use.
        attention_mechanism = BahdanauAttention(
            num_units=self.rnn_size,
            memory=encoder_outputs,
            memory_sequence_length=encoder_inputs_length)

        # Define the RNNCell for the decoder stage, then wrap it with attention.
        decoder_cell = self.create_rnn_cell()
        decoder_cell = AttentionWrapper(cell=decoder_cell,
                                        attention_mechanism=attention_mechanism,
                                        attention_layer_size=self.rnn_size,
                                        name='Attention_Wrapper')

        # With beam search, batch_size = self.batch_size * self.beam_size
        batch_size = self.batch_size if not self.beam_search else self.batch_size * self.beam_size

        # Initialize the decoder state directly from the encoder's final
        # hidden state.
        decoder_initial_state = decoder_cell.zero_state(
            batch_size=batch_size,
            dtype=tf.float32).clone(cell_state=encoder_state)

        output_layer = tf.layers.Dense(
            self.vocab_size,
            kernel_initializer=tf.truncated_normal_initializer(mean=0.0,
                                                               stddev=0.1))

        if self.mode == 'train':
            self.decoder_outputs = self.decoder_train(decoder_cell,
                                                      decoder_initial_state,
                                                      output_layer)
            # loss
            self.loss = sequence_loss(logits=self.decoder_outputs,
                                      targets=self.decoder_targets,
                                      weights=self.mask)
            # summary
            tf.summary.scalar('loss', self.loss)
            self.summary_op = tf.summary.merge_all()
            # optimizer
            optimizer = tf.train.AdamOptimizer(self.learing_rate)
            trainable_params = tf.trainable_variables()
            gradients = tf.gradients(self.loss, trainable_params)
            clip_gradients, _ = tf.clip_by_global_norm(gradients,
                                                       self.max_gradient_norm)
            self.train_op = optimizer.apply_gradients(
                zip(clip_gradients, trainable_params))
        elif self.mode == 'decode':
            self.decoder_predict_decode = self.decoder_decode(
                decoder_cell, decoder_initial_state, output_layer)
def _create_attention_mechanisms(self, beam_search=False):
    mechanisms = []
    layer_sizes = []

    if self._video_memory is not None:
        if beam_search is True:
            # TODO potentially broken, please re-check
            self._video_memory = seq2seq.tile_batch(
                self._video_memory, multiplier=self._hparams.beam_width)
            self._video_features_len = seq2seq.tile_batch(
                self._video_features_len, multiplier=self._hparams.beam_width)
        for attention_type in self._hparams.attention_type[0]:
            attention_video = self._create_attention_mechanism(
                num_units=self._hparams.decoder_units_per_layer[-1],
                memory=self._video_memory,
                memory_sequence_length=self._video_features_len,
                attention_type=attention_type)
            mechanisms.append(attention_video)
            # integer division keeps the projection size an int
            layer_sizes.append(self._hparams.decoder_units_per_layer[-1] // 2)

    if self._audio_memory is not None:
        if beam_search is True:
            # TODO potentially broken, please re-check
            self._audio_memory = seq2seq.tile_batch(
                self._audio_memory, multiplier=self._hparams.beam_width)
            self._audio_features_len = seq2seq.tile_batch(
                self._audio_features_len, multiplier=self._hparams.beam_width)
        for attention_type in self._hparams.attention_type[1]:
            attention_audio = self._create_attention_mechanism(
                num_units=self._hparams.decoder_units_per_layer[-1],
                memory=self._audio_memory,
                memory_sequence_length=self._audio_features_len,
                attention_type=attention_type)
            mechanisms.append(attention_audio)
            layer_sizes.append(self._hparams.decoder_units_per_layer[-1] // 2)

    return mechanisms, layer_sizes
def create_attention_mechanisms(num_units, attention_types, mode, dtype,
                                beam_search=False, beam_width=None,
                                memory=None, memory_len=None,
                                fusion_type=None):
    r"""
    Creates a list of attention mechanisms (e.g. seq2seq.BahdanauAttention)
    and a list of ints holding the attention projection layer sizes.

    Args:
        beam_search: `bool`, whether the beam-search decoding algorithm
            is used or not
    """
    mechanisms = []
    output_attention = None

    if beam_search is True:
        memory = seq2seq.tile_batch(memory, multiplier=beam_width)
        memory_len = seq2seq.tile_batch(memory_len, multiplier=beam_width)

    for attention_type in attention_types:
        attention, output_attention = create_attention_mechanism(
            num_units=num_units,  # has to match the decoder's state (query) size
            memory=memory,
            memory_sequence_length=memory_len,
            attention_type=attention_type,
            mode=mode,
            dtype=dtype)
        mechanisms.append(attention)

    N = len(attention_types)
    if fusion_type == 'deep_fusion':
        attention_layer_sizes = None
        attention_layers = [
            AttentionLayers(units=num_units, dtype=dtype) for _ in range(N)
        ]
    elif fusion_type == 'linear_fusion':
        attention_layer_sizes = [num_units] * N
        attention_layers = None
    else:
        raise Exception('Unknown fusion type')

    return mechanisms, attention_layers, attention_layer_sizes, output_attention
def _create_decoder_cell(self, encoder_outputs, encoder_state,
                         source_sequence_length):
    """Build an RNN cell that can be used by decoder."""
    # We only make use of encoder_outputs in attention-based models
    if self.attention_option:
        raise ValueError("BasicModel doesn't support attention.")

    cell = model_helper.create_rnn_cell(
        unit_type=self.unit_type,
        num_units=self.num_units,
        num_layers=self.num_decoder_layers,
        num_residual_layers=self.num_decoder_residual_layers,
        forget_bias=self.forget_bias,
        dropout=self.dropout,
        mode=self.mode)

    if self.mode == ModeKeys.INFER and self.beam_width > 0:
        # For beam search, we need to replicate encoder state `beam_width` times
        decoder_initial_state = seq2seq.tile_batch(
            encoder_state, multiplier=self.beam_width)
    else:
        decoder_initial_state = encoder_state

    return cell, decoder_initial_state
def _build_decoder_test_beam_search(self):
    r"""
    Builds a beam-search test decoder.
    """
    if self._hparams.enable_attention is True:
        cells, initial_state = self._add_attention(self._decoder_cells,
                                                   beam_search=True)
    else:
        # It is unclear whether the non-attentive beam decoder strictly
        # requires tile_batch (the graph also compiles without it), so we
        # tile the initial state to be safe.
        cells = self._decoder_cells
        decoder_initial_state_tiled = seq2seq.tile_batch(
            self._decoder_initial_state, multiplier=self._hparams.beam_width)
        initial_state = decoder_initial_state_tiled

    self._decoder_inference = seq2seq.BeamSearchDecoder(
        cell=cells,
        embedding=self._embedding_matrix,
        start_tokens=array_ops.fill([self._batch_size], self._GO_ID),
        end_token=self._EOS_ID,
        initial_state=initial_state,
        beam_width=self._hparams.beam_width,
        output_layer=self._dense_layer,
        length_penalty_weight=0.6)

    outputs, states, lengths = seq2seq.dynamic_decode(
        self._decoder_inference,
        impute_finished=False,
        maximum_iterations=self._hparams.max_label_length,
        swap_memory=False)

    self.inference_outputs = outputs.beam_search_decoder_output
    self.inference_predicted_ids = outputs.predicted_ids[:, :, 0]  # the first (best) beam
    self.inference_predicted_beam = outputs.predicted_ids
def add_attention(cells, attention_types, num_units, memory, memory_len,
                  mode, batch_size, dtype, beam_search=False, beam_width=None,
                  initial_state=None, write_attention_alignment=False,
                  fusion_type='linear_fusion'):
    r"""
    Wraps the decoder cells with an AttentionWrapper.

    Args:
        cells: instances of `RNNCell`
        beam_search: `bool` flag for beam search decoders
        batch_size: `Tensor` containing the batch size; necessary for the
            initialisation of the initial state

    Returns:
        attention_cells: the attention-wrapped decoder cells
        initial_state: a proper initial state to be used with the returned cells
    """
    attention_mechanisms, attention_layers, attention_layer_sizes, output_attention = \
        create_attention_mechanisms(
            beam_search=beam_search, beam_width=beam_width,
            memory=memory, memory_len=memory_len, num_units=num_units,
            attention_types=attention_types, fusion_type=fusion_type,
            mode=mode, dtype=dtype)

    if beam_search is True:
        initial_state = seq2seq.tile_batch(initial_state,
                                           multiplier=beam_width)

    attention_cells = seq2seq.AttentionWrapper(
        cell=cells,
        attention_mechanism=attention_mechanisms,
        attention_layer_size=attention_layer_sizes,
        # initial_cell_state=decoder_initial_state,
        alignment_history=write_attention_alignment,
        output_attention=output_attention,
        attention_layer=attention_layers)

    attn_zero = attention_cells.zero_state(
        dtype=dtype,
        batch_size=batch_size * beam_width if beam_search is True else batch_size)
    if initial_state is not None:
        initial_state = attn_zero.clone(cell_state=initial_state)

    return attention_cells, initial_state
def _build_single_attention_mechanism(memory):
    if not self._is_training:
        memory = seq2seq.tile_batch(memory, multiplier=self._beam_width)
    return seq2seq.BahdanauAttention(self._num_attention_units, memory,
                                     memory_sequence_length=None)
def _get_initial_state(initial_state, tiled_initial_state, cell, batch_size,
                       beam_width, dtype):
    if tiled_initial_state is None:
        if isinstance(initial_state, AttentionWrapperState):
            raise ValueError(
                '`initial_state` must not be an AttentionWrapperState. Use '
                'a plain cell state instead, which will be wrapped into an '
                'AttentionWrapperState automatically.')
        if initial_state is None:
            tiled_initial_state = cell.zero_state(batch_size * beam_width,
                                                  dtype)
        else:
            tiled_initial_state = tile_batch(initial_state,
                                             multiplier=beam_width)

    if isinstance(cell, AttentionWrapper) and \
            not isinstance(tiled_initial_state, AttentionWrapperState):
        zero_state = cell.zero_state(batch_size * beam_width, dtype)
        tiled_initial_state = zero_state.clone(cell_state=tiled_initial_state)

    return tiled_initial_state
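# A minimal sketch (assuming TF 1.x contrib APIs; the sizes are illustrative,
# not taken from the helper above) of the rule that helper enforces: pass a
# plain cell state, never an AttentionWrapperState, and let
# zero_state(...).clone(cell_state=...) build the full wrapped state.
import tensorflow as tf
from tensorflow.contrib import seq2seq

beam_width = 4
memory = tf.zeros([2, 7, 16])  # [batch, time, depth], toy values
cell = seq2seq.AttentionWrapper(
    tf.nn.rnn_cell.GRUCell(16),
    seq2seq.BahdanauAttention(16, seq2seq.tile_batch(memory, beam_width)))
plain_state = tf.zeros([2, 16])  # stand-in for the encoder's final GRU state
tiled_state = seq2seq.tile_batch(plain_state, multiplier=beam_width)
initial = cell.zero_state(2 * beam_width, tf.float32).clone(cell_state=tiled_state)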
def infer(self, cause_encoder):
    batch_size = tf.shape(self._initial_state)[0]
    tiled_initial_state = tile_batch(self._initial_state,
                                     multiplier=self._beam_width)
    tiled_initial_state = LSTMStateTuple(
        tiled_initial_state, tiled_initial_state,
        last_choice=array_ops.fill([batch_size * self._beam_width],
                                   self._SOS))
    infer_decoder = MyBeamSearchDecoder(
        self._lstm_cell,
        embedding=cause_encoder,
        start_tokens=tf.fill([batch_size], self._SOS),
        end_token=self._EOS,
        initial_state=tiled_initial_state,
        beam_width=self._beam_width,
        output_layer=self._project_dense,
        lookup_table=self._cause_table,
        length_penalty_weight=0.7,
        hie=self._hie)
    cause_output_infer, cause_state_infer, cause_length_infer = dynamic_decode(
        infer_decoder,
        parallel_iterations=128,
        maximum_iterations=self._max_cause_length - 1,
        scope='decoder')
    return cause_output_infer, cause_state_infer, cause_length_infer
def model(self):
    with tf.variable_scope("encoder"):
        encoder_cell = self._create_rnn_cell()
        source_embedding = tf.get_variable(
            name="source_embedding",
            shape=[self.source_vocab_size, self.embedding_size],
            initializer=tf.initializers.truncated_normal())
        encoder_embedding_inputs = tf.nn.embedding_lookup(source_embedding,
                                                          self.source_input)
        encoder_outputs, encoder_states = tf.nn.dynamic_rnn(
            cell=encoder_cell, inputs=encoder_embedding_inputs,
            dtype=tf.float32)

    with tf.variable_scope("decoder", reuse=tf.AUTO_REUSE):
        if self.mode == "test":
            encoder_states = seq2seq.tile_batch(encoder_states, self.beam_size)
        decoder_cell = self._create_rnn_cell()
        decoder_cell = rnn.DropoutWrapper(decoder_cell, output_keep_prob=0.5)
        if self.mode == "test":
            batch_size = self.batch_size * self.beam_size
        else:
            batch_size = self.batch_size
        # decoder_initial_state = decoder_cell.zero_state(batch_size=batch_size, dtype=tf.float32)
        output_layer = tf.layers.Dense(
            units=self.target_vocab_size,
            kernel_initializer=tf.truncated_normal_initializer(mean=0.0,
                                                               stddev=0.1))
        target_embedding = tf.get_variable(
            name="target_embedding",
            shape=[self.target_vocab_size, self.embedding_size])

        if self.mode == "train":
            self.mask = tf.sequence_mask(self.target_length,
                                         self.max_target_length,
                                         dtype=tf.float32)
            # Drop the last token and prepend the start token (id 2 here).
            del_end = tf.strided_slice(self.target_input, [0, 0],
                                       [self.batch_size, -1], [1, 1])
            decoder_input = tf.concat(
                [tf.fill([self.batch_size, 1], 2), del_end], axis=1)
            decoder_input_embedding = tf.nn.embedding_lookup(target_embedding,
                                                             decoder_input)
            training_helper = seq2seq.TrainingHelper(
                inputs=decoder_input_embedding,
                sequence_length=tf.fill([self.batch_size],
                                        self.max_target_length))
            training_decoder = seq2seq.BasicDecoder(
                cell=decoder_cell, helper=training_helper,
                initial_state=encoder_states, output_layer=output_layer)
            decoder_outputs, _, _ = seq2seq.dynamic_decode(
                decoder=training_decoder, output_time_major=False,
                impute_finished=True,
                maximum_iterations=self.max_target_length)

            self.decoder_logits_train = tf.identity(decoder_outputs.rnn_output)
            self.decoder_predict_train = tf.argmax(self.decoder_logits_train,
                                                   axis=-1)
            self.loss_op = tf.reduce_mean(tf.losses.softmax_cross_entropy(
                onehot_labels=tf.one_hot(self.target_input,
                                         depth=self.target_vocab_size),
                logits=self.decoder_logits_train,
                weights=self.mask))

            optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate)
            trainable_params = tf.trainable_variables()
            gradients = tf.gradients(self.loss_op, trainable_params)
            clip_gradients, _ = tf.clip_by_global_norm(gradients,
                                                       self.max_gradient_norm)
            self.train_op = optimizer.apply_gradients(
                zip(clip_gradients, trainable_params))
        elif self.mode == "test":
            start_tokens = tf.fill([self.batch_size], value=2)  # start token id
            end_token = 3  # end token id
            inference_decoder = seq2seq.BeamSearchDecoder(
                cell=decoder_cell,
                embedding=target_embedding,
                start_tokens=start_tokens,
                end_token=end_token,
                initial_state=encoder_states,
                beam_width=self.beam_size,
                output_layer=output_layer)
            decoder_outputs, _, _ = seq2seq.dynamic_decode(
                decoder=inference_decoder,
                maximum_iterations=self.max_target_length)
            print(decoder_outputs.predicted_ids.get_shape().as_list())
            self.decoder_predict_decode = decoder_outputs.predicted_ids[:, :, 0]
def build_attention_wrapper(self, final_cell):
    self.feedforward_inputs = tf.cond(
        self.beam_search_decoding,
        lambda: seq2seq.tile_batch(self.features["inputs"],
                                   multiplier=self.hparams.beam_width),
        lambda: self.features["inputs"])
    self.feedforward_inputs_length = tf.cond(
        self.beam_search_decoding,
        lambda: seq2seq.tile_batch(self.features["length"],
                                   multiplier=self.hparams.beam_width),
        lambda: self.features["length"])
    attention_mechanism = self.build_attention_mechanism()
    return AttentionWrapper(
        cell=final_cell,
        attention_mechanism=attention_mechanism,
        attention_layer_size=self.hparams.hidden_units,
        cell_input_fn=self._attention_input_feeding,
        initial_cell_state=self.initial_state[-1]
        if self.hparams.depth > 1 else self.initial_state)
def test_attention_decoder_given_initial_state(self):
    """Tests beam search with RNNAttentionDecoder given initial state.
    """
    seq_length = np.random.randint(self._max_time,
                                   size=[self._batch_size]) + 1
    encoder_values_length = tf.constant(seq_length)
    hparams = {
        "attention": {
            "kwargs": {
                "num_units": self._attention_dim
            }
        },
        "rnn_cell": {
            "kwargs": {
                "num_units": self._cell_dim
            }
        }
    }
    decoder = tx.modules.AttentionRNNDecoder(
        vocab_size=self._vocab_size,
        memory=self._encoder_output,
        memory_sequence_length=encoder_values_length,
        hparams=hparams)

    state = decoder.cell.zero_state(self._batch_size, tf.float32)
    cell_state = state.cell_state
    self._test_beam_search(decoder, initial_state=cell_state)

    tiled_cell_state = tile_batch(cell_state, multiplier=self._beam_width)
    self._test_beam_search(decoder,
                           tiled_initial_state=tiled_cell_state,
                           initiated=True)
def _build_decoder_cell(self, hparams, encoder_outputs, encoder_state,
                        source_sequence_length):
    """Build an RNN cell that can be used by decoder."""
    # We only make use of encoder_outputs in attention-based models
    if hparams.attention:
        raise ValueError("BasicModel doesn't support attention.")

    cell = model_helper.create_rnn_cell(
        unit_type=hparams.unit_type,
        num_units=hparams.num_units,
        num_layers=self.num_decoder_layers,
        num_residual_layers=self.num_decoder_residual_layers,
        forget_bias=hparams.forget_bias,
        dropout=hparams.dropout,
        num_gpus=self.num_gpus,
        mode=self.mode,
        single_cell_fn=self.single_cell_fn)

    # For beam search, we need to replicate encoder infos beam_width times
    if self.mode == tf.contrib.learn.ModeKeys.INFER and hparams.beam_width > 0:
        decoder_initial_state = seq2seq.tile_batch(
            encoder_state, multiplier=hparams.beam_width)
    else:
        decoder_initial_state = encoder_state

    return cell, decoder_initial_state
def create_decoder_cell():
    cell = tf.contrib.rnn.MultiRNNCell([
        utils.make_cell(self.args.enc_num_units,
                        utils.get_device_str(self.args.num_gpus))
        for _ in range(self.args.dec_layers)
    ])

    if self.args.beam_width > 0 and self.mode == "Infer":
        dec_start_state = seq2seq.tile_batch(self.encoder_state,
                                             self.beam_width)
        enc_outputs = seq2seq.tile_batch(self.encoder_outputs,
                                         self.beam_width)
        enc_lengths = seq2seq.tile_batch(self.encoder_inputs_length,
                                         self.beam_width)
    else:
        dec_start_state = self.encoder_state
        enc_outputs = self.encoder_outputs
        enc_lengths = self.encoder_inputs_length

    if self.args.attention:
        attention_states = enc_outputs
        attention_mechanism = tf.contrib.seq2seq.LuongAttention(
            self.args.dec_num_units, attention_states,
            memory_sequence_length=enc_lengths)
        decoder_cell = tf.contrib.seq2seq.AttentionWrapper(
            cell, attention_mechanism,
            attention_layer_size=self.args.dec_num_units)
        if self.args.beam_width > 0 and self.mode == "Infer":
            initial_state = decoder_cell.zero_state(
                self.batch_size * self.beam_width, tf.float32)
        else:
            initial_state = decoder_cell.zero_state(self.batch_size,
                                                    tf.float32)
        initial_state = initial_state.clone(cell_state=dec_start_state)
    else:
        decoder_cell = cell
        initial_state = dec_start_state

    return decoder_cell, initial_state
def _build_infer(self, config):
    # infer_decoder/beam_search
    # skip for flat_baseline
    tiled_inputs = tile_batch(self.xx_context, multiplier=config.beam_width)
    tiled_sequence_length = tile_batch(self.x_seq_length,
                                       multiplier=config.beam_width)
    tiled_first_attention = tile_batch(self.first_attention,
                                       multiplier=config.beam_width)
    attention_mechanism = BahdanauAttention(
        config.decode_size,
        memory=tiled_inputs,
        memory_sequence_length=tiled_sequence_length)
    tiled_xx_final = tile_batch(self.xx_final, config.beam_width)
    encoder_state2 = rnn.LSTMStateTuple(tiled_xx_final, tiled_xx_final)
    cell = AttentionWrapper(self.lstm, attention_mechanism,
                            output_attention=False)
    cell_state = cell.zero_state(
        dtype=tf.float32,
        batch_size=config.test_batch_size * config.beam_width)
    cell_state = cell_state.clone(cell_state=encoder_state2,
                                  attention=tiled_first_attention)
    infer_decoder = BeamSearchDecoder(
        cell,
        embedding=self.label_embeddings,
        start_tokens=[config.GO] * config.test_batch_size,
        end_token=config.EOS,
        initial_state=cell_state,
        beam_width=config.beam_width,
        output_layer=self.output_l)
    decoder_outputs_infer, decoder_state_infer, decoder_seq_infer = \
        dynamic_decode(infer_decoder,
                       maximum_iterations=config.max_seq_length)
    self.preds = decoder_outputs_infer.predicted_ids
    self.scores = decoder_state_infer.log_probs
def _build_decoder_beam_search(self):
    batch_size, _ = tf.unstack(tf.shape(self._labels))

    attention_mechanisms, layer_sizes = self._create_attention_mechanisms(
        beam_search=True)
    decoder_initial_state_tiled = seq2seq.tile_batch(
        self._decoder_initial_state, multiplier=self._hparams.beam_width)

    if self._hparams.enable_attention is True:
        attention_cells = seq2seq.AttentionWrapper(
            cell=self._decoder_cells,
            attention_mechanism=attention_mechanisms,
            attention_layer_size=layer_sizes,
            initial_cell_state=decoder_initial_state_tiled,
            alignment_history=self._hparams.write_attention_alignment,
            output_attention=self._output_attention)
        initial_state = attention_cells.zero_state(
            dtype=self._hparams.dtype,
            batch_size=batch_size * self._hparams.beam_width)
        initial_state = initial_state.clone(
            cell_state=decoder_initial_state_tiled)
        cells = attention_cells
    else:
        cells = self._decoder_cells
        initial_state = decoder_initial_state_tiled

    self._decoder_inference = seq2seq.BeamSearchDecoder(
        cell=cells,
        embedding=self._embedding_matrix,
        start_tokens=array_ops.fill([batch_size], self._GO_ID),
        end_token=self._EOS_ID,
        initial_state=initial_state,
        beam_width=self._hparams.beam_width,
        output_layer=self._dense_layer,
        length_penalty_weight=0.5)

    outputs, states, lengths = seq2seq.dynamic_decode(
        self._decoder_inference,
        impute_finished=False,
        maximum_iterations=self._hparams.max_label_length,
        swap_memory=False)

    if self._hparams.write_attention_alignment is True:
        self.attention_summary = self._create_attention_alignments_summary(states)

    self.inference_outputs = outputs.beam_search_decoder_output
    self.inference_predicted_ids = outputs.predicted_ids[:, :, 0]  # the first (best) beam
    self.inference_predicted_beam = outputs.predicted_ids
    self.beam_search_output = outputs.beam_search_decoder_output
def sample(self, n, max_length=None, z=None, temperature=None,
           start_inputs=None, beam_width=None, end_token=None):
    """Overrides BaseLstmDecoder `sample` method to add optional beam search.

    Args:
      n: Scalar number of samples to return.
      max_length: (Optional) Scalar maximum sample length to return. Required if
        data representation does not include end tokens.
      z: (Optional) Latent vectors to sample from. Required if model is
        conditional. Sized `[n, z_size]`.
      temperature: (Optional) The softmax temperature to use when not doing beam
        search. Defaults to 1.0. Ignored when `beam_width` is provided.
      start_inputs: (Optional) Initial inputs to use for batch. Sized
        `[n, output_depth]`.
      beam_width: (Optional) Width of beam to use for beam search. Beam search
        is disabled if not provided.
      end_token: (Optional) Scalar token signaling the end of the sequence to
        use for early stopping.

    Returns:
      samples: Sampled sequences. Sized `[n, max_length, output_depth]`.

    Raises:
      ValueError: If `z` is provided and its first dimension does not equal `n`.
    """
    if beam_width is None:
        end_fn = (None if end_token is None else
                  lambda x: tf.equal(tf.argmax(x, axis=-1), end_token))
        return super(CategoricalLstmDecoder, self).sample(
            n, max_length, z, temperature, start_inputs, end_fn)

    # If `end_token` is not given, use an impossible value.
    end_token = self._output_depth if end_token is None else end_token
    if z is not None and z.shape[0].value != n:
        raise ValueError(
            '`z` must have a first dimension that equals `n` when given. '
            'Got: %d vs %d' % (z.shape[0].value, n))

    if temperature is not None:
        tf.logging.warning('`temperature` is ignored when using beam search.')

    # Use a dummy Z in unconditional case.
    z = tf.zeros((n, 0), tf.float32) if z is None else z

    # If not given, start with dummy `-1` token and replace with zero vectors
    # in `embedding_fn`.
    start_tokens = (
        tf.argmax(start_inputs, axis=-1, output_type=tf.int32)
        if start_inputs is not None else
        -1 * tf.ones([n], dtype=tf.int32))

    initial_state = initial_cell_state_from_embedding(
        self._dec_cell, z, name='decoder/z_to_initial_state')
    beam_initial_state = seq2seq.tile_batch(initial_state,
                                            multiplier=beam_width)

    # Tile `z` across beams.
    beam_z = tf.tile(tf.expand_dims(z, 1), [1, beam_width, 1])

    def embedding_fn(tokens):
        # If tokens are the start_tokens (negative), replace with zero vectors.
        next_inputs = tf.cond(
            tf.less(tokens[0, 0], 0),
            lambda: tf.zeros([n, beam_width, self._output_depth]),
            lambda: tf.one_hot(tokens, self._output_depth))
        # Concatenate `z` to next inputs.
        next_inputs = tf.concat([next_inputs, beam_z], axis=-1)
        return next_inputs

    decoder = seq2seq.BeamSearchDecoder(
        self._dec_cell,
        embedding_fn,
        start_tokens,
        end_token,
        beam_initial_state,
        beam_width,
        output_layer=self._output_layer,
        length_penalty_weight=0.0)
    final_output, _, _ = seq2seq.dynamic_decode(
        decoder,
        maximum_iterations=max_length,
        swap_memory=True,
        scope='decoder')

    return tf.one_hot(final_output.predicted_ids[:, :, 0],
                      self._output_depth)
def decode_model(self, train_mode=True):
    with tf.variable_scope('decoder') as scope:
        encoder_outputs = self.encoder_outputs
        source_lengths = self.source_lengths
        encoder_state = self.encoder_state
        global_step = self.global_step
        projection_layer = self.projection_layer
        tgt_sos_id = self.tgt_sos_id
        tgt_eos_id = self.tgt_eos_id
        target = self.target
        target_lookup = self.tgt_lookup
        target_lengths = self.target_lengths
        learning_rate = self.learning_rate
        target_embed = self.tgt_embed
        reuse = False

        # Prepare
        b_size = tf.size(source_lengths)
        if train_mode == False:
            b_size_t = tf.to_int32(b_size * self.beam_width)
            encoder_outputs = tile_batch(encoder_outputs, self.beam_width)
            encoder_state = tile_batch(encoder_state, self.beam_width)
            source_lengths = tile_batch(source_lengths, self.beam_width)
        else:
            b_size_t = b_size

        # Create decoder_cell
        rnn_layer = [self.get_cell() for i in range(self.rnn_layer_depth)]

        # Create attention cell (on top of the rnn layers)
        multi_rnn = self.wrap_multi_rnn(rnn_layer)
        attention = self.wrap_attention(multi_rnn, encoder_outputs,
                                        source_lengths)
        attention = tf.contrib.rnn.DeviceWrapper(attention, '/cpu:0')

        # Sync cell state with the encoder
        decoder_state = attention.zero_state(
            batch_size=b_size_t, dtype=tf.float32).clone(
                cell_state=encoder_state)

        if train_mode == True:
            # define decoder
            decode_helper = tf.contrib.seq2seq.TrainingHelper(
                target_lookup, target_lengths)
            decoder = tf.contrib.seq2seq.BasicDecoder(
                attention, decode_helper, decoder_state,
                output_layer=projection_layer)
            outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(decoder,
                                                              scope=scope)
            logits = tf.identity(outputs.rnn_output)

            # Train loss
            weights = tf.to_float(tf.concat(
                [tf.expand_dims(tf.fill([b_size], tf.constant(True)), 1),
                 tf.not_equal(target[:, :-1], tgt_eos_id)], 1))
            loss = tf.contrib.seq2seq.sequence_loss(logits, target,
                                                    weights=weights)

            # Optimize
            params = tf.trainable_variables()
            gradients = tf.gradients(loss, params)
            clipped_gradients, _ = tf.clip_by_global_norm(gradients, 5.0)
            optimizer = tf.train.AdamOptimizer(learning_rate)
            opt = optimizer.apply_gradients(zip(clipped_gradients, params),
                                            global_step=global_step)
            return opt, loss, global_step
        else:
            infer_decoder = tf.contrib.seq2seq.BeamSearchDecoder(
                attention, target_embed,
                tf.fill([b_size], tf.to_int32(tgt_sos_id)),
                tf.to_int32(tgt_eos_id),
                decoder_state, self.beam_width,
                output_layer=projection_layer)
            infer_outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(
                infer_decoder,
                maximum_iterations=tf.round(tf.reduce_max(source_lengths) * 2),
                scope=scope)
            infer_result = infer_outputs.predicted_ids
            return infer_result, target, global_step, target
def seq_to_seq_net(embedding_dim, encoder_size, decoder_size, source_dict_dim,
                   target_dict_dim, is_generating, beam_size,
                   max_generation_length):
    src_word_idx = tf.placeholder(tf.int32, shape=[None, None])
    src_sequence_length = tf.placeholder(tf.int32, shape=[None, ])

    src_embedding_weights = tf.get_variable("source_word_embeddings",
                                            [source_dict_dim, embedding_dim])
    src_embedding = tf.nn.embedding_lookup(src_embedding_weights, src_word_idx)

    src_forward_cell = tf.nn.rnn_cell.BasicLSTMCell(encoder_size)
    src_reversed_cell = tf.nn.rnn_cell.BasicLSTMCell(encoder_size)
    # no peephole
    encoder_outputs, _ = tf.nn.bidirectional_dynamic_rnn(
        cell_fw=src_forward_cell,
        cell_bw=src_reversed_cell,
        inputs=src_embedding,
        sequence_length=src_sequence_length,
        dtype=tf.float32)

    # concat the forward outputs and backward outputs
    encoded_vec = tf.concat(encoder_outputs, axis=2)

    # project the encoder outputs to the size of the decoder lstm
    encoded_proj = tf.contrib.layers.fully_connected(
        inputs=tf.reshape(encoded_vec, shape=[-1, embedding_dim * 2]),
        num_outputs=decoder_size,
        activation_fn=None,
        biases_initializer=None)
    encoded_proj_reshape = tf.reshape(
        encoded_proj, shape=[-1, tf.shape(encoded_vec)[1], decoder_size])

    # get the init state for the decoder lstm's H
    backword_first = tf.slice(encoder_outputs[1], [0, 0, 0], [-1, 1, -1])
    decoder_boot = tf.contrib.layers.fully_connected(
        inputs=tf.reshape(backword_first, shape=[-1, embedding_dim]),
        num_outputs=decoder_size,
        activation_fn=tf.nn.tanh,
        biases_initializer=None)

    # prepare the initial state for the decoder lstm
    cell_init = tf.zeros(tf.shape(decoder_boot), tf.float32)
    initial_state = LSTMStateTuple(cell_init, decoder_boot)

    # create the decoder lstm cell; when generating, the attention inputs
    # are tiled beam_size times
    decoder_cell = LSTMCellWithSimpleAttention(
        decoder_size,
        encoded_vec if not is_generating
        else seq2seq.tile_batch(encoded_vec, beam_size),
        encoded_proj_reshape if not is_generating
        else seq2seq.tile_batch(encoded_proj_reshape, beam_size),
        src_sequence_length if not is_generating
        else seq2seq.tile_batch(src_sequence_length, beam_size),
        forget_bias=0.0)

    output_layer = Dense(target_dict_dim, name='output_projection')

    if not is_generating:
        trg_word_idx = tf.placeholder(tf.int32, shape=[None, None])
        trg_sequence_length = tf.placeholder(tf.int32, shape=[None, ])
        trg_embedding_weights = tf.get_variable(
            "target_word_embeddings", [target_dict_dim, embedding_dim])
        trg_embedding = tf.nn.embedding_lookup(trg_embedding_weights,
                                               trg_word_idx)

        training_helper = seq2seq.TrainingHelper(
            inputs=trg_embedding,
            sequence_length=trg_sequence_length,
            time_major=False,
            name='training_helper')
        training_decoder = seq2seq.BasicDecoder(
            cell=decoder_cell,
            helper=training_helper,
            initial_state=initial_state,
            output_layer=output_layer)

        # get the max length of the target sequence
        max_decoder_length = tf.reduce_max(trg_sequence_length)
        decoder_outputs_train, _, _ = seq2seq.dynamic_decode(
            decoder=training_decoder,
            output_time_major=False,
            impute_finished=True,
            maximum_iterations=max_decoder_length)
        decoder_logits_train = tf.identity(decoder_outputs_train.rnn_output)
        decoder_pred_train = tf.argmax(decoder_logits_train, axis=-1,
                                       name='decoder_pred_train')
        masks = tf.sequence_mask(lengths=trg_sequence_length,
                                 maxlen=max_decoder_length,
                                 dtype=tf.float32,
                                 name='masks')

        # placeholder for the label sequence
        lbl_word_idx = tf.placeholder(tf.int32, shape=[None, None])

        # compute the loss
        loss = seq2seq.sequence_loss(
            logits=decoder_logits_train,
            targets=lbl_word_idx,
            weights=masks,
            average_across_timesteps=True,
            average_across_batch=True)

        # return the feeding list and the loss operator
        return {
            'src_word_idx': src_word_idx,
            'src_sequence_length': src_sequence_length,
            'trg_word_idx': trg_word_idx,
            'trg_sequence_length': trg_sequence_length,
            'lbl_word_idx': lbl_word_idx
        }, loss
    else:
        start_tokens = tf.ones([tf.shape(src_word_idx)[0], ],
                               tf.int32) * START_TOKEN_IDX
        # share the same embedding weights with the target words
        trg_embedding_weights = tf.get_variable(
            "target_word_embeddings", [target_dict_dim, embedding_dim])

        inference_decoder = beam_search_decoder.BeamSearchDecoder(
            cell=decoder_cell,
            embedding=lambda tokens: tf.nn.embedding_lookup(
                trg_embedding_weights, tokens),
            start_tokens=start_tokens,
            end_token=END_TOKEN_IDX,
            initial_state=tf.nn.rnn_cell.LSTMStateTuple(
                tf.contrib.seq2seq.tile_batch(initial_state[0], beam_size),
                tf.contrib.seq2seq.tile_batch(initial_state[1], beam_size)),
            beam_width=beam_size,
            output_layer=output_layer)

        decoder_outputs_decode, _, _ = seq2seq.dynamic_decode(
            decoder=inference_decoder,
            output_time_major=False,
            # impute_finished=True raises an error here
            maximum_iterations=max_generation_length)
        predicted_ids = decoder_outputs_decode.predicted_ids

        return {
            'src_word_idx': src_word_idx,
            'src_sequence_length': src_sequence_length
        }, predicted_ids
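# All of the snippets above follow the same recipe: tile the encoder memory,
# its sequence lengths and the encoder state by beam_width, build the
# attention cell against the tiled memory, clone the tiled encoder state into
# the wrapper's zero_state, and decode with a BeamSearchDecoder. A minimal,
# self-contained sketch of that recipe (TF 1.x contrib APIs; all sizes, token
# ids and placeholder names below are illustrative, not taken from any
# snippet above):
import tensorflow as tf
from tensorflow.contrib import seq2seq
from tensorflow.python.layers.core import Dense

batch_size, beam_width, num_units, vocab_size = 8, 4, 64, 1000
enc_outputs = tf.placeholder(tf.float32, [None, None, num_units])
enc_state = tf.nn.rnn_cell.LSTMStateTuple(
    tf.placeholder(tf.float32, [None, num_units]),
    tf.placeholder(tf.float32, [None, num_units]))
source_len = tf.placeholder(tf.int32, [None])
embeddings = tf.get_variable('embeddings', [vocab_size, num_units])

# 1. Replicate memory, lengths and state once per beam.
tiled_outputs = seq2seq.tile_batch(enc_outputs, multiplier=beam_width)
tiled_state = seq2seq.tile_batch(enc_state, multiplier=beam_width)
tiled_len = seq2seq.tile_batch(source_len, multiplier=beam_width)

# 2. Build attention against the *tiled* memory.
attention = seq2seq.BahdanauAttention(num_units, tiled_outputs,
                                      memory_sequence_length=tiled_len)
cell = seq2seq.AttentionWrapper(tf.nn.rnn_cell.LSTMCell(num_units),
                                attention, attention_layer_size=num_units)

# 3. zero_state at batch_size * beam_width, then clone in the tiled state.
initial_state = cell.zero_state(batch_size * beam_width,
                                tf.float32).clone(cell_state=tiled_state)

# 4. Decode; ids 1 and 2 stand in for the start and end tokens.
decoder = seq2seq.BeamSearchDecoder(
    cell, embeddings,
    start_tokens=tf.fill([batch_size], 1), end_token=2,
    initial_state=initial_state, beam_width=beam_width,
    output_layer=Dense(vocab_size))
outputs, _, _ = seq2seq.dynamic_decode(decoder, maximum_iterations=50)
best_ids = outputs.predicted_ids[:, :, 0]  # top beam, shape [batch, time]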