def train_helper():
    start_ids = tf.fill([batch_size, 1], self._output_sos_id)
    decoder_input_ids = tf.concat([start_ids, self.output_data], 1)
    decoder_inputs = self._output_onehot(decoder_input_ids)
    return seq2seq.ScheduledEmbeddingTrainingHelper(
        decoder_inputs, self.output_lengths, self._output_onehot,
        sampling_probability)
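Note that `_output_onehot` serves double duty here: it "embeds" the ground-truth input ids, and it is also the `embedding` callable the helper applies to ids it samples from the decoder output. A minimal sketch of what it might look like, assuming a one-hot encoding over an output vocabulary (the size attribute is hypothetical):

# Hypothetical sketch: one-hot "embedding" usable both for the decoder
# inputs and for ids sampled by ScheduledEmbeddingTrainingHelper.
def _output_onehot(self, ids):
    # self._output_vocab_size is an assumed attribute naming the depth.
    return tf.one_hot(ids, depth=self._output_vocab_size, dtype=tf.float32)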
def _basic_decoder_train(self):
    r"""
    Builds the standard teacher-forcing training decoder with sampling
    from previous predictions.
    """
    helper_train = seq2seq.ScheduledEmbeddingTrainingHelper(
        inputs=self._decoder_train_inputs,
        sequence_length=self._labels_len,
        embedding=self._embedding_matrix,
        sampling_probability=self._sampling_probability_outputs,
    )

    if self._hparams.enable_attention is True:
        cells, initial_state = self._add_attention(self._decoder_cells)
    else:
        cells = self._decoder_cells
        initial_state = self._decoder_initial_state

    decoder_train = seq2seq.BasicDecoder(
        cell=cells,
        helper=helper_train,
        initial_state=initial_state,
        output_layer=self._dense_layer,
    )

    outputs, fstate, fseqlen = seq2seq.dynamic_decode(
        decoder_train,
        output_time_major=False,
        impute_finished=True,
        swap_memory=False,
    )

    return outputs, fstate, fseqlen
def get_DecoderHelper(embedding_lookup, seq_lengths, token_dim,
                      gt_tokens=None, unroll_type='teacher_forcing'):
    # Note: `self` (providing sample_prob and batch_size) is captured from
    # the enclosing method's scope; this is a nested helper, not a free
    # function.
    if unroll_type == 'teacher_forcing':
        if gt_tokens is None:
            raise ValueError('teacher_forcing requires gt_tokens')
        embedding = embedding_lookup(gt_tokens)
        helper = seq2seq.TrainingHelper(embedding, seq_lengths)
    elif unroll_type == 'scheduled_sampling':
        if gt_tokens is None:
            raise ValueError('scheduled_sampling requires gt_tokens')
        embedding = embedding_lookup(gt_tokens)
        # sample_prob 1.0: always sample from ground truth
        # sample_prob 0.0: always sample from prediction
        helper = seq2seq.ScheduledEmbeddingTrainingHelper(
            embedding, seq_lengths, embedding_lookup,
            1.0 - self.sample_prob, seed=None, scheduling_seed=None)
    elif unroll_type == 'greedy':
        # During evaluation, we perform greedy unrolling.
        start_token = tf.zeros([self.batch_size], dtype=tf.int32) + token_dim
        end_token = token_dim - 1
        helper = seq2seq.GreedyEmbeddingHelper(embedding_lookup,
                                               start_token, end_token)
    else:
        raise ValueError('Unknown unroll type')
    return helper
def get_DecoderHelper(embedding_lookup, seq_lengths, token_dim,
                      gt_tokens=None, sequence_type='program',
                      unroll_type='teacher_forcing'):
    if unroll_type == 'teacher_forcing':
        if gt_tokens is None:
            raise ValueError('teacher_forcing requires gt_tokens')
        embedding = embedding_lookup(gt_tokens)
        helper = seq2seq.TrainingHelper(embedding, seq_lengths)
    elif unroll_type == 'scheduled_sampling':
        if gt_tokens is None:
            raise ValueError('scheduled_sampling requires gt_tokens')
        embedding = embedding_lookup(gt_tokens)
        # sample_prob 1.0: always sample from ground truth
        # sample_prob 0.0: always sample from prediction
        helper = seq2seq.ScheduledEmbeddingTrainingHelper(
            embedding, seq_lengths, embedding_lookup,
            1.0 - self.sample_prob, seed=None, scheduling_seed=None)
    elif unroll_type == 'greedy':
        # During evaluation, we perform greedy unrolling.
        start_token = tf.zeros([self.batch_size], dtype=tf.int32) + token_dim
        if sequence_type == 'program':
            end_token = self.vocab.token2int['m)']
        elif sequence_type == 'action':
            end_token = token_dim - 1
        else:
            # Hack to have no end token, greater than number of perceptions
            end_token = 11
        helper = seq2seq.GreedyEmbeddingHelper(embedding_lookup,
                                               start_token, end_token)
    else:
        raise ValueError('Unknown unroll type')
    return helper
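Both versions of get_DecoderHelper read `self.sample_prob` but leave its schedule to the caller. A minimal sketch of the inverse-sigmoid decay proposed in the scheduled-sampling paper (Bengio et al., 2015); the constant `k` is a hypothetical hyperparameter:

# Hypothetical inverse-sigmoid decay: sample_prob starts near 1.0
# (mostly ground truth) and decays toward 0.0 as training progresses.
k = 1000.0  # assumed decay constant; larger k decays more slowly
step = tf.cast(tf.train.get_or_create_global_step(), tf.float32)
sample_prob = k / (k + tf.exp(step / k))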
def _basic_decoder_train(self):
    r"""
    Builds the standard teacher-forcing training decoder with sampling
    from previous predictions.
    """
    helper_train = seq2seq.ScheduledEmbeddingTrainingHelper(
        inputs=self._decoder_train_inputs,
        sequence_length=self._labels_len,
        embedding=self._embedding_matrix,
        sampling_probability=self._sampling_probability_outputs,
    )

    # christian_fun = lambda logits: tf.math.top_k(logits, 3).indices
    #
    # helper_train = seq2seq.ScheduledOutputTrainingHelper(
    #     inputs=self._decoder_train_inputs,
    #     sequence_length=self._labels_len,
    #     sampling_probability=self._sampling_probability_outputs,
    # )

    if self._hparams.enable_attention is True:
        cells, initial_state = add_attention(
            cells=self._decoder_cells,
            attention_types=self._hparams.attention_type[1],
            num_units=self._hparams.decoder_units_per_layer[-1],
            memory=self._encoder_memory,
            memory_len=self._encoder_features_len,
            initial_state=self._decoder_initial_state,
            batch_size=self._batch_size,
            mode=self._mode,
            dtype=self._hparams.dtype,
            fusion_type='linear_fusion',
            write_attention_alignment=False,  # we are in train mode
        )
    else:
        cells = self._decoder_cells
        initial_state = self._decoder_initial_state

    decoder_train = seq2seq.BasicDecoder(
        cell=cells,
        helper=helper_train,
        initial_state=initial_state,
        output_layer=self._dense_layer,
    )

    outputs, fstate, fseqlen = seq2seq.dynamic_decode(
        decoder_train,
        output_time_major=False,
        impute_finished=True,
        swap_memory=False,
    )

    return outputs, fstate, fseqlen
def _closure(word_embeddings):
    tf.summary.scalar('decoder_sampling_p', self._decoder_sampling_p)
    decoder_targets_embedded = tf.nn.embedding_lookup(
        word_embeddings, add_pad_eos(self.targets, self.targets_length))
    return seq2seq.ScheduledEmbeddingTrainingHelper(
        inputs=decoder_targets_embedded,
        sequence_length=(self.targets_length + 2),
        embedding=word_embeddings,
        sampling_probability=self._decoder_sampling_p)
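`add_pad_eos` is defined elsewhere in that codebase; given the `targets_length + 2` passed as the sequence length above, it plausibly appends an EOS and extends the padding by one step. A hypothetical sketch (the token ids and the assumption that positions at or beyond `targets_length` already hold `pad_id` are mine):

# Hypothetical sketch of add_pad_eos: extend the time axis by two steps
# and write an EOS id at each sequence's true end position.
def add_pad_eos(targets, targets_length, eos_id=1, pad_id=0):
    padded = tf.pad(targets, [[0, 0], [0, 2]], constant_values=pad_id)
    eos_at_end = tf.one_hot(targets_length,
                            depth=tf.shape(padded)[1],
                            dtype=targets.dtype) * (eos_id - pad_id)
    return padded + eos_at_end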
def _build_decoder_train(self):
    self._labels_embedded = tf.nn.embedding_lookup(self._embedding_matrix,
                                                   self._labels_padded_GO)

    self._helper_train = seq2seq.ScheduledEmbeddingTrainingHelper(
        inputs=self._labels_embedded,
        sequence_length=self._labels_len,
        embedding=self._embedding_matrix,
        sampling_probability=self._sampling_probability_outputs,
    )

    if self._hparams.enable_attention is True:
        attention_mechanisms, layer_sizes = \
            self._create_attention_mechanisms()

        attention_cells = seq2seq.AttentionWrapper(
            cell=self._decoder_cells,
            attention_mechanism=attention_mechanisms,
            attention_layer_size=layer_sizes,
            initial_cell_state=self._decoder_initial_state,
            alignment_history=False,
            output_attention=self._output_attention,
        )

        batch_size, _ = tf.unstack(tf.shape(self._labels))
        attn_zero = attention_cells.zero_state(dtype=self._hparams.dtype,
                                               batch_size=batch_size)
        initial_state = attn_zero.clone(
            cell_state=self._decoder_initial_state)
        cells = attention_cells
    else:
        cells = self._decoder_cells
        initial_state = self._decoder_initial_state

    self._decoder_train = seq2seq.BasicDecoder(
        cell=cells,
        helper=self._helper_train,
        initial_state=initial_state,
        output_layer=self._dense_layer,
    )

    self._basic_decoder_train_outputs, self._final_states, \
        self._final_seq_lens = seq2seq.dynamic_decode(
            self._decoder_train,
            output_time_major=False,
            impute_finished=True,
            swap_memory=False,
        )

    self._logits = self._basic_decoder_train_outputs.rnn_output
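`_labels_padded_GO` is built elsewhere; presumably it is the label ids with a GO token prepended, so the decoder input at step t is the label at t-1. A hypothetical sketch (the GO id attribute name is assumed):

# Hypothetical sketch: prepend a GO id to every label sequence to form
# the teacher-forced decoder inputs.
go_column = tf.fill([tf.shape(self._labels)[0], 1], self._GO_ID)  # assumed id
self._labels_padded_GO = tf.concat([go_column, self._labels], axis=1)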
def _build_helper(self, batch_size, embeddings, inputs, inputs_length,
                  mode, hparams, decoder_hparams):
    """Builds a helper instance for BasicDecoder."""
    # Auxiliary decoding mode at training time.
    if decoder_hparams.auxiliary:
        start_tokens = tf.fill([batch_size], text_encoder.PAD_ID)
        # helper = helpers.FixedContinuousEmbeddingHelper(
        #     embedding=embeddings,
        #     start_tokens=start_tokens,
        #     end_token=text_encoder.EOS_ID,
        #     num_steps=hparams.aux_decode_length)
        helper = contrib_seq2seq.SampleEmbeddingHelper(
            embedding=embeddings,
            start_tokens=start_tokens,
            end_token=text_encoder.EOS_ID,
            softmax_temperature=None)

    # Continuous decoding.
    elif hparams.decoder_continuous:
        # Scheduled mixing.
        if mode == tf.estimator.ModeKeys.TRAIN and hparams.scheduled_training:
            helper = helpers.ScheduledContinuousEmbeddingTrainingHelper(
                inputs=inputs,
                sequence_length=inputs_length,
                mixing_concentration=hparams.scheduled_mixing_concentration)
        # Pure continuous decoding (hard to train!).
        elif mode == tf.estimator.ModeKeys.TRAIN:
            helper = helpers.ContinuousEmbeddingTrainingHelper(
                inputs=inputs,
                sequence_length=inputs_length)
        # EVAL and PREDICT expect teacher-forcing behavior.
        else:
            helper = contrib_seq2seq.TrainingHelper(
                inputs=inputs,
                sequence_length=inputs_length)

    # Standard decoding.
    else:
        # Scheduled sampling.
        if mode == tf.estimator.ModeKeys.TRAIN and hparams.scheduled_training:
            helper = contrib_seq2seq.ScheduledEmbeddingTrainingHelper(
                inputs=inputs,
                sequence_length=inputs_length,
                embedding=embeddings,
                sampling_probability=hparams.scheduled_sampling_probability)
        # Teacher forcing (also for EVAL and PREDICT).
        else:
            helper = contrib_seq2seq.TrainingHelper(
                inputs=inputs,
                sequence_length=inputs_length)

    return helper
def _build_decoder_train(self):
    self._decoder_train_inputs = tf.nn.embedding_lookup(
        self._embedding_matrix, self._labels_padded_GO)

    if self._mode == 'train':
        sampler = seq2seq.ScheduledEmbeddingTrainingHelper(
            inputs=self._decoder_train_inputs,
            sequence_length=self._labels_length,
            embedding=self._embedding_matrix,
            sampling_probability=self._sampling_probability_outputs,
        )
    else:
        sampler = seq2seq.TrainingHelper(
            inputs=self._decoder_train_inputs,
            sequence_length=self._labels_length,
        )

    cells = self._decoder_cells

    decoder_train = seq2seq.BasicDecoder(
        cell=cells,
        helper=sampler,
        initial_state=self._decoder_initial_state,
        output_layer=self._dense_layer,
    )

    outputs, _, _ = seq2seq.dynamic_decode(
        decoder_train,
        output_time_major=False,
        impute_finished=True,
        swap_memory=False,
    )

    logits = outputs.rnn_output
    self.decoder_train_outputs = logits
    self.average_log_likelihoods = self._compute_likelihood(logits)
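`_compute_likelihood` isn't shown. A plausible sketch is a length-masked average log-likelihood per sequence; the label attribute (`self._labels`) and its alignment with the logits' time axis are assumptions:

# Hypothetical sketch of _compute_likelihood: masked per-sequence
# average log-likelihood of the ground-truth labels under the logits.
def _compute_likelihood(self, logits):
    log_probs = -tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=self._labels, logits=logits)  # assumed label attribute
    mask = tf.sequence_mask(self._labels_length,
                            maxlen=tf.shape(logits)[1],
                            dtype=log_probs.dtype)
    total = tf.reduce_sum(log_probs * mask, axis=1)
    return total / tf.cast(self._labels_length, log_probs.dtype)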
def build_train_decoder(self):
    self.decoder_inputs_embedded = tf.nn.embedding_lookup(
        params=self.embedding, ids=self.decoder_inputs_train)

    if self.train_mode == 'ground_truth':
        # inputs: the embedded decoder input; with time_major=False its
        #     shape is [batch_size, sequence_length, embedding_size], with
        #     time_major=True it is [sequence_length, batch_size,
        #     embedding_size].
        # sequence_length: sparsely documented, but the source shows it is
        #     the length of each sequence in the current batch
        #     (self._batch_size = array_ops.size(sequence_length)).
        # time_major: determines what the first two dims of `inputs` mean.
        # name: as documented.
        training_helper = seq2seq.TrainingHelper(
            inputs=self.decoder_inputs_embedded,
            sequence_length=self.decoder_inputs_length_train,
            time_major=False,
            name='training_helper')
    elif self.train_mode == 'scheduled_sampling':
        training_helper = seq2seq.ScheduledEmbeddingTrainingHelper(
            inputs=self.decoder_inputs_embedded,
            sequence_length=self.decoder_inputs_length_train,
            embedding=lambda inputs: tf.nn.embedding_lookup(
                self.embedding, inputs),
            sampling_probability=self.sampling_probability,
            name='scheduled_embedding_training_helper')
    else:
        raise NotImplementedError(
            'Train mode: {} is not yet implemented'.format(self.train_mode))

    training_decoder = seq2seq.BasicDecoder(
        cell=self.decoder_cell,
        helper=training_helper,
        initial_state=self.decoder_initial_state,
        output_layer=self.output_layer)
    # Longest sequence in the batch.
    max_decoder_length = tf.reduce_max(self.decoder_inputs_length_train)

    # dynamic_decode returns a namedtuple with two fields
    # (rnn_output, sample_id):
    # rnn_output: [batch_size, decoder_targets_length, vocab_size], the
    #     per-step scores over the vocabulary, used to compute the loss.
    # sample_id: [batch_size], tf.int32, the decoded ids, i.e. the answer.
    self.decoder_outputs_train, self.decoder_last_state_train, \
        self.decoder_outputs_length_train = seq2seq.dynamic_decode(
            decoder=training_decoder,
            output_time_major=False,
            impute_finished=True,
            maximum_iterations=max_decoder_length)

    # tf.identity creates send/recv nodes inside the graph to reference or
    # copy a tensor; its main use is finer control over passing values
    # between devices.
    self.decoder_logits_train = tf.identity(
        self.decoder_outputs_train.rnn_output)

    # Debug: inspect the logits shape.
    cnn_pre = self.decoder_logits_train[-1]
    print(cnn_pre.get_shape())
    # self.decoder_logits_train[-1] = tf.nn.conv1d(
    #     self.decoder_logits_train, self.kernel, stride=1, padding="SAME")
    # print(self.decoder_logits_train.get_shape())
    self.decoder_logits_train = tf.nn.conv1d(self.decoder_logits_train,
                                             self.kernel,
                                             stride=1,
                                             padding="SAME")

    # tf.argmax returns the index of the maximum along the given axis.
    self.decoder_pred_train = tf.argmax(self.decoder_logits_train,
                                        axis=-1,
                                        name='decoder_pred_train')

    # masks: masking for valid and padded time steps,
    # [batch_size, max_time_step + 1]
    masks = tf.sequence_mask(lengths=self.decoder_inputs_length_train,
                             maxlen=max_decoder_length,
                             dtype=self.dtype,
                             name='masks')

    # Computes the sequence loss directly; internally calls
    # 'nn_ops.sparse_softmax_cross_entropy_with_logits' by default.
    # logits: [batch_size, sequence_length, num_decoder_symbols]
    # targets: [batch_size, sequence_length]; no one-hot needed.
    # weights: [batch_size, sequence_length], i.e. the mask that excludes
    #     padded steps so the loss is computed only over valid positions.
    self.loss = seq2seq.sequence_loss(logits=self.decoder_logits_train,
                                      targets=self.decoder_targets_train,
                                      weights=masks,
                                      average_across_timesteps=True,
                                      average_across_batch=True)

    # Training summary for the current batch loss (scalar summary).
    tf.summary.scalar('loss', self.loss)

    # Construct graphs for minimizing loss.
    self.init_optimizer()
def _build_sentence_decoder(self, inputs, context_encoder_outputs,
                            sentence_encoder_final_states,
                            sentence_encoder_outputs):
    batch_size = self._batch_size
    num_sentence = self._num_sentence

    word_embedding = model_helper.create_word_embedding(
        num_vocab=self.hparams.num_vocab,
        embedding_dim=self.hparams.word_embedding_dim,
        name='decoder/word_embedding',
        pretrained_word_matrix=self.hparams.pretrained_word_path)

    # tile_batch in inference mode
    beam_width = self.hparams.beam_width
    if self.mode == tf.contrib.learn.ModeKeys.INFER:
        # only decode the last timestep
        if 'lstm' in self.hparams.rnn_cell_type.lower():
            batched_sentence_encoder_states = []
            for encoder_state in sentence_encoder_final_states:
                target_shape = tf.stack([batch_size, num_sentence, -1])
                c = s2s.tile_batch(
                    tf.reshape(encoder_state.c, target_shape)[:, -1, :],
                    beam_width)
                h = s2s.tile_batch(
                    tf.reshape(encoder_state.h, target_shape)[:, -1, :],
                    beam_width)
                batched_sentence_encoder_states.append(
                    tf.contrib.rnn.LSTMStateTuple(c=c, h=h))
        else:
            batched_sentence_encoder_states = [
                s2s.tile_batch(
                    tf.reshape(
                        encoder_state,
                        tf.stack([batch_size, num_sentence, -1]))[:, -1, :],
                    beam_width)
                for encoder_state in sentence_encoder_final_states
            ]
        sentence_encoder_final_states = tuple(
            batched_sentence_encoder_states)

        sentence_encoder_outputs = s2s.tile_batch(
            tf.reshape(
                sentence_encoder_outputs,
                tf.stack([batch_size, num_sentence, -1,
                          self.hparams.num_rnn_units]))[:, -1, :, :],
            beam_width)
        source_lengths = s2s.tile_batch(inputs.src_lengths[:, -1],
                                        beam_width)

        context_encoder_outputs = tf.reshape(
            context_encoder_outputs,
            tf.stack([batch_size, num_sentence,
                      self.hparams.num_rnn_units]))[:, -1, :]
        context_encoder_outputs = tf.tile(
            tf.expand_dims(context_encoder_outputs, axis=1),
            [1, beam_width, 1])
        effective_batch_size = self._batch_size * beam_width
    else:
        source_lengths = tf.reshape(inputs.src_lengths, [-1])
        context_encoder_outputs.set_shape(
            [None, self.hparams.num_rnn_units])
        effective_batch_size = self._batch_size * self._num_sentence

    # Current strategy: no residual layers at the decoder.
    attention_mechanism = model_helper.create_attention_mechanism(
        attention_option=self.hparams.attention_type,
        num_units=self.hparams.num_rnn_units,
        memory=sentence_encoder_outputs,
        source_length=source_lengths)

    decoder_cell = s2s.AttentionWrapper(
        model_helper.create_rnn_cell(
            cell_type=self.hparams.rnn_cell_type,
            num_layers=self.hparams.num_rnn_layers,
            num_units=self.hparams.num_rnn_units,
            dropout_keep_prob=self._dropout_keep_prob,
            num_residual_layers=0),
        attention_mechanism,
        attention_layer_size=self.hparams.num_rnn_units,
        alignment_history=False,
        name="attention")

    decoder_initial_state = decoder_cell.zero_state(effective_batch_size,
                                                    tf.float32)
    decoder_initial_state = decoder_initial_state.clone(
        cell_state=sentence_encoder_final_states)

    with tf.variable_scope('output_projection'):
        output_layer = layers_core.Dense(self.hparams.num_vocab,
                                         name="output_projection")
    self.output_layer = output_layer

    if self.mode in {tf.contrib.learn.ModeKeys.TRAIN,
                     tf.contrib.learn.ModeKeys.EVAL}:
        decoder_input_tokens = tf.reshape(
            inputs.targets_in, tf.stack([batch_size * num_sentence, -1]))
        decoder_inputs = tf.nn.embedding_lookup(word_embedding,
                                                decoder_input_tokens)
        target_lengths = tf.reshape(inputs.tgt_lengths, [-1])

        # NOTE: the `and False` disables the scheduled-sampling branch;
        # training always falls through to plain teacher forcing below.
        if self.mode == tf.contrib.learn.ModeKeys.TRAIN and False:
            sampling_probability = 1.0 - tf.train.exponential_decay(
                1.0,
                self.global_step,
                self.hparams.scheduled_sampling_decay_steps,
                self.hparams.scheduled_sampling_decay_rate,
                staircase=True,
                name='scheduled_sampling_prob')
            helper = s2s.ScheduledEmbeddingTrainingHelper(
                inputs=decoder_inputs,
                sequence_length=target_lengths,
                embedding=word_embedding,
                sampling_probability=sampling_probability,
                name='scheduled_sampling_helper')
        else:
            helper = s2s.TrainingHelper(
                inputs=decoder_inputs,
                sequence_length=target_lengths,
                name='training_helper',
            )

        decoder = s2s.BasicDecoder(decoder_cell, helper,
                                   decoder_initial_state,
                                   output_layer=None)
        final_outputs, final_state, _ = dynamic_decode_with_concat(
            decoder, context_encoder_outputs, swap_memory=True)
        logits = final_outputs.rnn_output
        sample_id = final_outputs.sample_id
    else:
        sos_id = tf.cast(
            self.vocab_table.lookup(tf.constant(dataset.SOS)), tf.int32)
        eos_id = tf.cast(
            self.vocab_table.lookup(tf.constant(dataset.EOS)), tf.int32)
        sos_ids = tf.fill([batch_size], sos_id)

        decoder = s2s.BeamSearchDecoder(
            cell=decoder_cell,
            embedding=word_embedding,
            start_tokens=sos_ids,
            end_token=eos_id,
            initial_state=decoder_initial_state,
            beam_width=beam_width,
            output_layer=self.output_layer)
        final_outputs, final_state, _ = dynamic_decode_with_concat(
            decoder,
            context_encoder_outputs,
            maximum_iterations=self.hparams.target_max_length,
            swap_memory=True)
        logits = final_outputs.beam_search_decoder_output.scores
        sample_id = final_outputs.predicted_ids

    return logits, final_state, sample_id
def build_train_decoder(self):
    self.decoder_inputs_embedded = tf.nn.embedding_lookup(
        params=self.embedding, ids=self.decoder_inputs_train)

    if self.train_mode == 'ground_truth':
        training_helper = seq2seq.TrainingHelper(
            inputs=self.decoder_inputs_embedded,
            sequence_length=self.decoder_inputs_length_train,
            time_major=False,
            name='training_helper')
    elif self.train_mode == 'scheduled_sampling':
        training_helper = seq2seq.ScheduledEmbeddingTrainingHelper(
            inputs=self.decoder_inputs_embedded,
            sequence_length=self.decoder_inputs_length_train,
            embedding=lambda inputs: tf.nn.embedding_lookup(
                self.embedding, inputs),
            sampling_probability=self.sampling_probability,
            name='scheduled_embedding_training_helper')
    else:
        raise NotImplementedError(
            'Train mode: {} is not yet implemented'.format(self.train_mode))

    training_decoder = seq2seq.BasicDecoder(
        cell=self.decoder_cell,
        helper=training_helper,
        initial_state=self.decoder_initial_state,
        output_layer=self.output_layer)
    max_decoder_length = tf.reduce_max(self.decoder_inputs_length_train)

    self.decoder_outputs_train, self.decoder_last_state_train, \
        self.decoder_outputs_length_train = seq2seq.dynamic_decode(
            decoder=training_decoder,
            output_time_major=False,
            impute_finished=True,
            maximum_iterations=max_decoder_length)

    # NOTE(sdsuo): Not sure why this is necessary
    self.decoder_logits_train = tf.identity(
        self.decoder_outputs_train.rnn_output)

    # Use argmax to extract decoder symbols to emit
    self.decoder_pred_train = tf.argmax(self.decoder_logits_train,
                                        axis=-1,
                                        name='decoder_pred_train')

    # masks: masking for valid and padded time steps,
    # [batch_size, max_time_step + 1]
    masks = tf.sequence_mask(lengths=self.decoder_inputs_length_train,
                             maxlen=max_decoder_length,
                             dtype=self.dtype,
                             name='masks')

    # Computes per-word average cross-entropy over a batch; internally
    # calls 'nn_ops.sparse_softmax_cross_entropy_with_logits' by default.
    self.loss = seq2seq.sequence_loss(logits=self.decoder_logits_train,
                                      targets=self.decoder_targets_train,
                                      weights=masks,
                                      average_across_timesteps=True,
                                      average_across_batch=True)

    # Training summary for the current batch loss.
    tf.summary.scalar('loss', self.loss)

    # Construct graphs for minimizing loss.
    self.init_optimizer()
def _basic_decoder_train(self):
    r"""
    Builds the standard teacher-forcing training decoder with sampling
    from previous predictions.
    """
    helper_train = seq2seq.ScheduledEmbeddingTrainingHelper(
        inputs=self._decoder_train_inputs,
        sequence_length=self._labels_len,
        embedding=self._embedding_matrix,
        sampling_probability=self._sampling_probability_outputs,
    )

    # christian_fun = lambda logits: tf.math.top_k(logits, 3).indices
    #
    # helper_train = seq2seq.ScheduledOutputTrainingHelper(
    #     inputs=self._decoder_train_inputs,
    #     sequence_length=self._labels_len,
    #     sampling_probability=self._sampling_probability_outputs,
    # )

    if self._hparams.enable_attention is True:
        print('_encoder_memory', self._encoder_memory)
        print(self._decoder_initial_state)
        cells, initial_state = add_attention(
            cell_type=self._cell_type,
            cells=self._decoder_cells,
            attention_types=self._hparams.attention_type[1],
            num_units=self._hparams.decoder_units_per_layer[-1],
            memory=self._encoder_memory,
            memory_len=self._encoder_features_len,
            initial_state=self._decoder_initial_state,
            batch_size=self._batch_size,
            mode=self._mode,
            dtype=self._hparams.dtype,
            fusion_type='linear_fusion',
            write_attention_alignment=False,  # we are in train mode
        )
    else:
        cells = self._decoder_cells
        initial_state = self._decoder_initial_state

    print('Decoder_unimodal_cells', cells)
    print('Decoder_unimodal_initial_state', initial_state)

    decoder_train = seq2seq.BasicDecoder(
        cell=cells,
        helper=helper_train,
        initial_state=initial_state,
        output_layer=self._dense_layer,
    )

    out = seq2seq.dynamic_decode(
        decoder_train,
        output_time_major=False,
        impute_finished=True,
        swap_memory=False,
    )
    outputs, fstate, fseqlen = out

    if "skip" in self._cell_type:
        outputs, updated_states = outputs
        print("DecoderUnimodal_updated_states", updated_states)
        cost_per_sample = self._hparams.cost_per_sample[2]
        budget_loss = tf.reduce_mean(
            tf.reduce_sum(cost_per_sample * updated_states, 1), 0)
        meanUpdates = tf.reduce_mean(tf.reduce_sum(updated_states, 1), 0)
        self.skip_infos = SkipInfoTuple(updated_states, meanUpdates,
                                        budget_loss)

    return outputs, fstate, fseqlen
def build_graph(self):
    enc_outputs, enc_state = tf.nn.dynamic_rnn(
        cell=self.cell(),
        inputs=self.enc_input,
        sequence_length=self.lengths,
        dtype=tf.float32)

    # Replicate the top-most encoder state as the starting state of all
    # layers in the decoder.
    dec_start_state = tuple(enc_state[-1] for _ in range(self.layers))

    output = Dense(
        self.vocab_size,
        kernel_initializer=tf.truncated_normal_initializer(stddev=0.1))

    # Training decoder: scheduled sampling et al.
    with tf.variable_scope("decode"):
        cell = self.decoder_cell(enc_outputs, self.lengths)
        init_state = cell.zero_state(self.batch_size, tf.float32)
        init_state = init_state.clone(cell_state=dec_start_state)
        train_helper = seq2seq.ScheduledEmbeddingTrainingHelper(
            inputs=self.dec_input,
            sequence_length=self.lengths,
            embedding=self.dec_embed,
            sampling_probability=0.1)
        train_decoder = seq2seq.BasicDecoder(cell=cell,
                                             helper=train_helper,
                                             initial_state=init_state,
                                             output_layer=output)
        train_output, _, train_lengths = seq2seq.dynamic_decode(
            decoder=train_decoder, maximum_iterations=self.maxLength)

    # Inference decoder: beam search over batch-tiled encoder outputs.
    dec_start_state = seq2seq.tile_batch(dec_start_state, self.beam_width)
    enc_outputs = seq2seq.tile_batch(enc_outputs, self.beam_width)
    lengths = seq2seq.tile_batch(self.lengths, self.beam_width)

    with tf.variable_scope("decode", reuse=True):
        cell = self.decoder_cell(enc_outputs, lengths)
        init_state = cell.zero_state(self.batch_size * self.beam_width,
                                     tf.float32)
        init_state = init_state.clone(cell_state=dec_start_state)
        test_decoder = seq2seq.BeamSearchDecoder(
            cell=cell,
            embedding=self.dec_embed,
            start_tokens=tf.ones(self.batch_size, dtype=tf.int32) *
            self.tokens['GO'],
            end_token=self.tokens['EOS'],
            initial_state=init_state,
            beam_width=self.beam_width,
            output_layer=output)
        test_output, _, test_lengths = seq2seq.dynamic_decode(
            decoder=test_decoder, maximum_iterations=self.maxLength)

    # Create train op.
    mask = tf.sequence_mask(train_lengths + 1,
                            self.maxLength - 1,
                            dtype=tf.float32)
    self.cost = seq2seq.sequence_loss(train_output.rnn_output,
                                      self.add_eos[:, :-1], mask)
    self.train_op = tf.train.AdamOptimizer(0.001).minimize(self.cost)

    # Create test error-rate op.
    predicts = self.to_sparse(test_output.predicted_ids[:, :, 0],
                              test_lengths[:, 0] - 1)
    labels = self.to_sparse(self.add_eos, self.lengths)
    self.error_rate = tf.reduce_mean(tf.edit_distance(predicts, labels))
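This graph hard-codes `sampling_probability=0.1`, so unlike the annealed schedules in the other snippets it never changes during training. A minimal sketch of making it feedable instead (the placeholder name and `schedule` callable are hypothetical):

# Hypothetical: make the sampling probability feedable so it can be
# annealed from the training loop rather than fixed at 0.1.
sampling_prob = tf.placeholder_with_default(
    tf.constant(0.1, tf.float32), shape=[], name='sampling_prob')
# Pass sampling_probability=sampling_prob to the helper, then e.g.:
# sess.run(train_op, feed_dict={sampling_prob: schedule(step)})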
def build_train_decoder(self):
    self.decoder_inputs_embedded = tf.nn.embedding_lookup(
        params=self.embedding, ids=self.decoder_inputs_train)

    # At training time, use the TrainingHelper + BasicDecoder combination.
    if self.train_mode == 'ground_truth':
        training_helper = seq2seq.TrainingHelper(
            inputs=self.decoder_inputs_embedded,
            sequence_length=self.decoder_inputs_length_train,
            time_major=False,
            name='training_helper')
    elif self.train_mode == 'scheduled_sampling':
        training_helper = seq2seq.ScheduledEmbeddingTrainingHelper(
            inputs=self.decoder_inputs_embedded,
            sequence_length=self.decoder_inputs_length_train,
            embedding=lambda inputs: tf.nn.embedding_lookup(
                self.embedding, inputs),
            sampling_probability=self.sampling_probability,
            name='scheduled_embedding_training_helper')
    else:
        raise NotImplementedError(
            'Train mode: {} is not yet implemented'.format(self.train_mode))

    training_decoder = seq2seq.BasicDecoder(
        cell=self.decoder_cell,
        helper=training_helper,
        initial_state=self.decoder_initial_state,
        output_layer=self.output_layer)
    max_decoder_length = tf.reduce_max(self.decoder_inputs_length_train)

    # dynamic_decode returns decoder_outputs_train, a namedtuple with two
    # fields (rnn_output, sample_id):
    # rnn_output: [batch_size, decoder_targets_length, vocab_size], the
    #     per-step scores over the vocabulary, used to compute the loss.
    # sample_id: [batch_size], tf.int32, the decoded ids, i.e. the answer.
    self.decoder_outputs_train, self.decoder_last_state_train, \
        self.decoder_outputs_length_train = seq2seq.dynamic_decode(
            decoder=training_decoder,
            output_time_major=False,
            impute_finished=True,
            maximum_iterations=max_decoder_length)

    # Compute loss and gradients from the outputs, and define the
    # AdamOptimizer and train_op that apply the updates.
    self.decoder_logits_train = tf.identity(
        self.decoder_outputs_train.rnn_output)

    # Use argmax to extract decoder symbols to emit
    self.decoder_pred_train = tf.argmax(self.decoder_logits_train,
                                        axis=-1,
                                        name='decoder_pred_train')

    # masks: masking for valid and padded time steps,
    # [batch_size, max_time_step + 1]
    masks = tf.sequence_mask(lengths=self.decoder_inputs_length_train,
                             maxlen=max_decoder_length,
                             dtype=self.dtype,
                             name='masks')

    # Computes per-word average cross-entropy over a batch; internally
    # calls 'nn_ops.sparse_softmax_cross_entropy_with_logits' by default.
    # The masks defined above are passed in as the loss weights.
    self.loss = seq2seq.sequence_loss(logits=self.decoder_logits_train,
                                      targets=self.decoder_targets_train,
                                      weights=masks,
                                      average_across_timesteps=True,
                                      average_across_batch=True)

    # Training summary for the current batch loss.
    tf.summary.scalar('loss', self.loss)

    # Construct graphs for minimizing loss.
    self.init_optimizer()
def __init__(self, config, batch_size, decoder_input, latent_variables,
             embedding, output_len, vocab_size, go_idx, eos_idx,
             is_training=True, ru=False):
    self.config = config
    with tf.name_scope("decoder_input"):
        self.batch_size = batch_size
        self.decoder_input = decoder_input
        self.latent_variables = latent_variables
        self.embedding = embedding
        self.output_len = output_len
        self.vocab_size = vocab_size
        self.go_idx = go_idx
        self.eos_idx = eos_idx
        self.is_training = is_training

    with tf.variable_scope("Length_Control"):
        if self.config.LEN_EMB_SIZE > 0:
            self.len_embeddings = tf.get_variable(
                name="len_embeddings",
                shape=(self.config.NUM_LEN_EMB, self.config.LEN_EMB_SIZE),
                dtype=tf.float32,
                initializer=tf.random_normal_initializer(stddev=0.1))

    def create_cell():
        if self.config.RNN_CELL == 'lnlstm':
            cell = tf.contrib.rnn.LayerNormBasicLSTMCell(
                self.config.DEC_RNN_SIZE)
        elif self.config.RNN_CELL == 'lstm':
            cell = tf.contrib.rnn.BasicLSTMCell(self.config.DEC_RNN_SIZE)
        elif self.config.RNN_CELL == 'gru':
            cell = tf.contrib.rnn.GRUCell(self.config.DEC_RNN_SIZE)
        else:
            logger.error('rnn_cell {} not supported'.format(
                self.config.RNN_CELL))
            # Fail fast instead of falling through with `cell` unbound.
            raise ValueError('rnn_cell {} not supported'.format(
                self.config.RNN_CELL))
        if self.is_training:
            cell = tf.nn.rnn_cell.DropoutWrapper(
                cell, output_keep_prob=self.config.DROPOUT_KEEP)
        return cell

    cell = tf.nn.rnn_cell.MultiRNNCell([create_cell() for _ in range(2)])
    projection_layer = Dense(self.vocab_size)
    projection_layer.build(self.config.DEC_RNN_SIZE)
    self.beam_ids = self.get_beam_ids(cell, projection_layer)

    if self.config.LEN_EMB_SIZE > 0:
        initial_state = cell.zero_state(self.batch_size, dtype=tf.float32)
        cell = LenControlWrapper(cell,
                                 self.output_len,
                                 self.len_embeddings,
                                 initial_cell_state=initial_state)
    initial_state = cell.zero_state(self.batch_size, dtype=tf.float32)
    cell = AlignmentWrapper(cell,
                            latent_variables,
                            initial_cell_state=initial_state)
    initial_state = cell.zero_state(self.batch_size, dtype=tf.float32)

    if self.is_training:
        decoder_emb_inputs = tf.nn.embedding_lookup(self.embedding,
                                                    self.decoder_input)
        helper = seq2seq.ScheduledEmbeddingTrainingHelper(
            decoder_emb_inputs, self.output_len, self.embedding,
            self.config.SAMP_PROB)
    else:
        helper = seq2seq.GreedyEmbeddingHelper(self.embedding,
                                               self.go_input(),
                                               self.eos_idx)

    decoder = seq2seq.BasicDecoder(cell,
                                   helper,
                                   initial_state=initial_state,
                                   output_layer=None)
    outputs, _, seq_len = seq2seq.dynamic_decode(
        decoder, maximum_iterations=tf.reduce_max(self.output_len))
    self.rnn_output = outputs.rnn_output
    self.proj_weights = projection_layer.kernel
    self.proj_bias = projection_layer.bias

    bow_h = tf.layers.dense(self.latent_variables,
                            self.config.BOW_SIZE,
                            activation=tf.tanh)
    if self.is_training:
        bow_h = tf.nn.dropout(bow_h, self.config.DROPOUT_KEEP)
    self.bow_logits = tf.layers.dense(bow_h,
                                      self.vocab_size,
                                      name="bow_logits")
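`go_input()` isn't shown; it presumably materializes a batch of GO start ids for GreedyEmbeddingHelper. A one-line hypothetical sketch:

# Hypothetical sketch of go_input: one GO start-token id per batch entry.
def go_input(self):
    return tf.fill([self.batch_size], self.go_idx)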
def build_decoder(self):
    print('build decoder...')
    with tf.variable_scope('decoder'):
        self.decoder_cell, self.decoder_initial_state = \
            self.build_decoder_cell()
        self.decoder_embedding = tf.get_variable(
            name='embedding',
            shape=[self.para.decoder_vocab_size, self.para.embedding_size],
            dtype=self.dtype)
        output_projection_layer = Dense(units=self.para.decoder_vocab_size,
                                        name='output_projection')

        if self.para.mode == 'train':
            self.decoder_inputs_embedded = tf.nn.embedding_lookup(
                params=self.decoder_embedding, ids=self.decoder_inputs)

            if self.para.scheduled_sampling == 0:
                training_helper = seq2seq.TrainingHelper(
                    inputs=self.decoder_inputs_embedded,
                    sequence_length=self.decoder_inputs_len,
                    name='training_helper')
            else:
                # Linear warm-up: ramp the sampling probability from 0 to 1
                # over the first start_decay_step * 2 steps, then hold at 1.
                self.sampling_probability = tf.cond(
                    self.global_step < self.para.start_decay_step * 2,
                    lambda: tf.cast(
                        tf.divide(self.global_step,
                                  self.para.start_decay_step * 2),
                        dtype=self.dtype),
                    lambda: tf.constant(1.0, dtype=self.dtype),
                    name='sampling_probability')
                training_helper = seq2seq.ScheduledEmbeddingTrainingHelper(
                    inputs=self.decoder_inputs_embedded,
                    sequence_length=self.decoder_inputs_len,
                    embedding=self.decoder_embedding,
                    sampling_probability=self.sampling_probability,
                    name='training_helper')

            training_decoder = seq2seq.BasicDecoder(
                cell=self.decoder_cell,
                helper=training_helper,
                initial_state=self.decoder_initial_state,
                output_layer=output_projection_layer)
            max_decoder_length = tf.reduce_max(self.decoder_inputs_len)

            self.decoder_outputs, decoder_states, decoder_outputs_len = \
                seq2seq.dynamic_decode(
                    decoder=training_decoder,
                    maximum_iterations=max_decoder_length)

            rnn_output = self.decoder_outputs.rnn_output
            # rnn_output must be padded to max_len; the masks in the loss
            # computation handle the padded positions.
            self.rnn_output_padded = tf.pad(
                rnn_output,
                [[0, 0],
                 [0, self.para.max_len - tf.shape(rnn_output)[1]],
                 [0, 0]])
            self.loss = self.compute_loss(logits=self.rnn_output_padded,
                                          labels=self.decoder_targets)
        elif self.para.mode == 'test':
            start_tokens = tf.fill([self.para.batch_size], 1)

            if self.para.beam_search == 0:
                inference_helper = seq2seq.GreedyEmbeddingHelper(
                    start_tokens=start_tokens,
                    end_token=2,
                    embedding=self.decoder_embedding)
                inference_decoder = seq2seq.BasicDecoder(
                    cell=self.decoder_cell,
                    helper=inference_helper,
                    initial_state=self.decoder_initial_state,
                    output_layer=output_projection_layer)
            else:
                inference_decoder = seq2seq.BeamSearchDecoder(
                    cell=self.decoder_cell,
                    embedding=self.decoder_embedding,
                    start_tokens=start_tokens,
                    end_token=2,
                    initial_state=self.decoder_initial_state,
                    beam_width=self.para.beam_width,
                    output_layer=output_projection_layer)

            self.decoder_outputs, decoder_states, decoder_outputs_len = \
                seq2seq.dynamic_decode(
                    decoder=inference_decoder,
                    maximum_iterations=self.para.max_len)

            if self.para.beam_search == 0:
                # self.decoder_predicted_ids: [batch_size, max_len, 1]
                self.decoder_predicted_ids = tf.expand_dims(
                    input=self.decoder_outputs.sample_id, axis=-1)
            else:
                # self.decoder_predicted_ids:
                # [batch_size, <= max_len, beam_width]
                self.decoder_predicted_ids = \
                    self.decoder_outputs.predicted_ids