def _build_decoder(self, decoder_cell, batch_size, lstm_holistic_features,
                   cnn_fmap):
    embedding_fn = functools.partial(tf.one_hot, depth=self.num_classes)
    output_layer = Compute2dAttentionLayer(512, self.num_classes, cnn_fmap)
    if self._is_training:
        train_helper = seq2seq.TrainingHelper(
            embedding_fn(self._groundtruth_dict['decoder_inputs']),
            sequence_length=self._groundtruth_dict['decoder_lengths'],
            time_major=False)
        decoder = seq2seq.BasicDecoder(
            cell=decoder_cell,
            helper=train_helper,
            initial_state=lstm_holistic_features,
            output_layer=output_layer)
    else:
        # NOTE: the states must be tiled with seq2seq.tile_batch (which
        # repeats each example beam_width times in place) rather than
        # tf.tile (which repeats the whole batch); BeamSearchDecoder
        # expects the tile_batch ordering.
        lstm0_state_tile = tf.nn.rnn_cell.LSTMStateTuple(
            seq2seq.tile_batch(lstm_holistic_features[0].c, self._beam_width),
            seq2seq.tile_batch(lstm_holistic_features[0].h, self._beam_width))
        lstm1_state_tile = tf.nn.rnn_cell.LSTMStateTuple(
            seq2seq.tile_batch(lstm_holistic_features[1].c, self._beam_width),
            seq2seq.tile_batch(lstm_holistic_features[1].h, self._beam_width))
        lstm_holistic_features_tile = (lstm0_state_tile, lstm1_state_tile)
        decoder = seq2seq.BeamSearchDecoder(
            cell=decoder_cell,
            embedding=embedding_fn,
            start_tokens=tf.fill([batch_size], self.start_label),
            end_token=self.end_label,
            initial_state=lstm_holistic_features_tile,
            beam_width=self._beam_width,
            output_layer=output_layer,
            length_penalty_weight=0.0)
    return decoder
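# Why the fix above matters: seq2seq.tile_batch repeats each example in
# place ([a, a, b, b]) while tf.tile repeats the whole batch ([a, b, a, b]),
# and BeamSearchDecoder assumes the former ordering. A tiny TF 1.x sanity
# check (not from the original code):
import tensorflow as tf
from tensorflow.contrib import seq2seq

state = tf.constant([[1., 1.], [2., 2.]])              # batch of 2 examples
per_example = seq2seq.tile_batch(state, multiplier=2)  # [[1,1],[1,1],[2,2],[2,2]]
whole_batch = tf.tile(state, [2, 1])                   # [[1,1],[2,2],[1,1],[2,2]]
with tf.Session() as sess:
    print(sess.run(per_example))
    print(sess.run(whole_batch))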
def model(self):
    with tf.variable_scope("encoder"):
        encoder_cell = self._create_rnn_cell()
        source_embedding = tf.get_variable(
            name="source_embedding",
            shape=[self.source_vocab_size, self.embedding_size],
            initializer=tf.initializers.truncated_normal())
        encoder_embedding_inputs = tf.nn.embedding_lookup(
            source_embedding, self.source_input)
        encoder_outputs, encoder_states = tf.nn.dynamic_rnn(
            cell=encoder_cell,
            inputs=encoder_embedding_inputs,
            dtype=tf.float32)

    with tf.variable_scope("decoder", reuse=tf.AUTO_REUSE):
        if self.mode == "test":
            # Repeat each encoder state beam_size times for beam search.
            encoder_states = seq2seq.tile_batch(encoder_states,
                                                self.beam_size)
        decoder_cell = self._create_rnn_cell()
        if self.mode == "train":
            # Dropout should only be active during training.
            decoder_cell = rnn.DropoutWrapper(decoder_cell,
                                              output_keep_prob=0.5)
        if self.mode == "test":
            batch_size = self.batch_size * self.beam_size
        else:
            batch_size = self.batch_size
        # decoder_initial_state = decoder_cell.zero_state(batch_size=batch_size, dtype=tf.float32)
        output_layer = tf.layers.Dense(
            units=self.target_vocab_size,
            kernel_initializer=tf.truncated_normal_initializer(mean=0.0,
                                                               stddev=0.1))
        target_embedding = tf.get_variable(
            name="target_embedding",
            shape=[self.target_vocab_size, self.embedding_size])

        if self.mode == "train":
            self.mask = tf.sequence_mask(self.target_length,
                                         self.max_target_length,
                                         dtype=tf.float32)
            # Drop the last target token and prepend the start token (id 2).
            del_end = tf.strided_slice(self.target_input, [0, 0],
                                       [self.batch_size, -1], [1, 1])
            decoder_input = tf.concat(
                [tf.fill([self.batch_size, 1], 2), del_end], axis=1)
            decoder_input_embedding = tf.nn.embedding_lookup(
                target_embedding, decoder_input)
            training_helper = seq2seq.TrainingHelper(
                inputs=decoder_input_embedding,
                sequence_length=tf.fill([self.batch_size],
                                        self.max_target_length))
            training_decoder = seq2seq.BasicDecoder(
                cell=decoder_cell,
                helper=training_helper,
                initial_state=encoder_states,
                output_layer=output_layer)
            decoder_outputs, _, _ = seq2seq.dynamic_decode(
                decoder=training_decoder,
                output_time_major=False,
                impute_finished=True,
                maximum_iterations=self.max_target_length)
            self.decoder_logits_train = tf.identity(
                decoder_outputs.rnn_output)
            self.decoder_predict_train = tf.argmax(
                self.decoder_logits_train, axis=-1)
            self.loss_op = tf.reduce_mean(
                tf.losses.softmax_cross_entropy(
                    onehot_labels=tf.one_hot(self.target_input,
                                             depth=self.target_vocab_size),
                    logits=self.decoder_logits_train,
                    weights=self.mask))
            optimizer = tf.train.AdamOptimizer(
                learning_rate=self.learning_rate)
            trainable_params = tf.trainable_variables()
            gradients = tf.gradients(self.loss_op, trainable_params)
            clip_gradients, _ = tf.clip_by_global_norm(
                gradients, self.max_gradient_norm)
            self.train_op = optimizer.apply_gradients(
                zip(clip_gradients, trainable_params))
        elif self.mode == "test":
            start_tokens = tf.fill([self.batch_size], value=2)
            end_token = 3
            inference_decoder = seq2seq.BeamSearchDecoder(
                cell=decoder_cell,
                embedding=target_embedding,
                start_tokens=start_tokens,
                end_token=end_token,
                initial_state=encoder_states,
                beam_width=self.beam_size,
                output_layer=output_layer)
            decoder_outputs, _, _ = seq2seq.dynamic_decode(
                decoder=inference_decoder,
                maximum_iterations=self.max_target_length)
            # predicted_ids: [batch_size, max_time, beam_width]
            print(decoder_outputs.predicted_ids.get_shape().as_list())
            # Keep only the top beam.
            self.decoder_predict_decode = \
                decoder_outputs.predicted_ids[:, :, 0]
def _build_decoder(self, decoder_cell, batch_size):
    embedding_fn = functools.partial(tf.one_hot, depth=self.num_classes)
    output_layer = tf.layers.Dense(
        self.num_classes,
        activation=None,
        use_bias=True,
        kernel_initializer=tf.variance_scaling_initializer(),
        bias_initializer=tf.zeros_initializer())
    if self._is_training:
        train_helper = seq2seq.TrainingHelper(
            embedding_fn(self._groundtruth_dict['decoder_inputs']),
            sequence_length=self._groundtruth_dict['decoder_lengths'],
            time_major=False)
        decoder = seq2seq.BasicDecoder(
            cell=decoder_cell,
            helper=train_helper,
            initial_state=decoder_cell.zero_state(batch_size, tf.float32),
            output_layer=output_layer)
    else:
        decoder = seq2seq.BeamSearchDecoder(
            cell=decoder_cell,
            embedding=embedding_fn,
            start_tokens=tf.fill([batch_size], self.start_label),
            end_token=self.end_label,
            initial_state=decoder_cell.zero_state(
                batch_size * self._beam_width, tf.float32),
            beam_width=self._beam_width,
            output_layer=output_layer,
            length_penalty_weight=0.0)
    return decoder
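# A minimal, self-contained sketch (not from the original code) of how a
# decoder built by a function like the one above is driven with
# seq2seq.dynamic_decode. Assumes TF 1.x with tf.contrib.seq2seq; the cell
# size, vocabulary size, and token ids below are hypothetical stand-ins.
import tensorflow as tf
from tensorflow.contrib import seq2seq

num_classes, beam_width, batch_size = 30, 5, 8
start_label, end_label = 0, 1

cell = tf.nn.rnn_cell.GRUCell(64)
decoder = seq2seq.BeamSearchDecoder(
    cell=cell,
    embedding=lambda ids: tf.one_hot(ids, depth=num_classes),  # one-hot "embedding"
    start_tokens=tf.fill([batch_size], start_label),
    end_token=end_label,
    initial_state=cell.zero_state(batch_size * beam_width, tf.float32),
    beam_width=beam_width,
    output_layer=tf.layers.Dense(num_classes),
    length_penalty_weight=0.0)

outputs, final_state, lengths = seq2seq.dynamic_decode(
    decoder, maximum_iterations=20)
top_beam_ids = outputs.predicted_ids[:, :, 0]  # [batch_size, max_time]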
def _build_decoder_test_beam_search(self): r""" Builds a beam search test decoder """ if self._hparams.enable_attention is True: cells, initial_state = self._add_attention(self._decoder_cells, beam_search=True) else: # does the non-attentive beam decoder need tile_batch ? cells = self._decoder_cells decoder_initial_state_tiled = seq2seq.tile_batch( # guess so ? it compiles without it too self._decoder_initial_state, multiplier=self._hparams.beam_width) initial_state = decoder_initial_state_tiled self._decoder_inference = seq2seq.BeamSearchDecoder( cell=cells, embedding=self._embedding_matrix, start_tokens=array_ops.fill([self._batch_size], self._GO_ID), end_token=self._EOS_ID, initial_state=initial_state, beam_width=self._hparams.beam_width, output_layer=self._dense_layer, length_penalty_weight=0.6, ) outputs, states, lengths = seq2seq.dynamic_decode( self._decoder_inference, impute_finished=False, maximum_iterations=self._hparams.max_label_length, swap_memory=False) self.inference_outputs = outputs.beam_search_decoder_output self.inference_predicted_ids = outputs.predicted_ids[:, :, 0] # return the first beam self.inference_predicted_beam = outputs.predicted_ids
def inference_decode_layer(self, start_token, dec_cell, end_token,
                           output_layer):
    start_tokens = tf.tile(tf.constant([start_token], dtype=tf.int32),
                           [self.batch_size], name='start_token')
    # Tile everything the attention mechanism reads to
    # batch_size * beam_width.
    tiled_enc_output = seq2seq.tile_batch(self.enc_output,
                                          multiplier=self.Beam_width)
    tiled_enc_state = seq2seq.tile_batch(self.enc_state,
                                         multiplier=self.Beam_width)
    tiled_source_len = seq2seq.tile_batch(self.source_len,
                                          multiplier=self.Beam_width)
    atten_mech = seq2seq.BahdanauAttention(self.hidden_dim * 2,
                                           tiled_enc_output,
                                           tiled_source_len,
                                           normalize=True)
    decoder_att = seq2seq.AttentionWrapper(dec_cell, atten_mech,
                                           self.hidden_dim * 2)
    initial_state = decoder_att.zero_state(
        self.batch_size * self.Beam_width,
        tf.float32).clone(cell_state=tiled_enc_state)
    decoder = seq2seq.BeamSearchDecoder(decoder_att,
                                        self.embeddings,
                                        start_tokens,
                                        end_token,
                                        initial_state,
                                        beam_width=self.Beam_width,
                                        output_layer=output_layer)
    # Despite the name, this holds a FinalBeamSearchDecoderOutput
    # (predicted_ids plus per-beam scores), not raw logits.
    infer_logits, _, _ = seq2seq.dynamic_decode(
        decoder,
        output_time_major=False,
        impute_finished=False,
        maximum_iterations=self.max_target_len)
    return infer_logits
def beam_eval_decoder(agenda, embeddings, extended_base_words, oov,
                      start_token_id, stop_token_id, base_sent_hiddens,
                      base_length, vocab_size, attn_dim, hidden_dim,
                      num_layer, max_sentence_length, beam_width,
                      swap_memory, enable_dropout=False, dropout_keep=1.,
                      no_insert_delete_attn=False):
    with tf.variable_scope(OPS_NAME, 'decoder', reuse=True):
        true_batch_size = tf.shape(base_sent_hiddens)[0]

        # Tile every decoder input to batch_size * beam_width.
        tiled_agenda = seq2seq.tile_batch(agenda, beam_width)
        tiled_extended_base_words = seq2seq.tile_batch(extended_base_words,
                                                       beam_width)
        tiled_oov = seq2seq.tile_batch(oov, beam_width)
        tiled_base_sent = seq2seq.tile_batch(base_sent_hiddens, beam_width)
        tiled_base_lengths = seq2seq.tile_batch(base_length, beam_width)

        start_token_id = tf.cast(start_token_id, tf.int32)
        stop_token_id = tf.cast(stop_token_id, tf.int32)

        cell, zero_states = create_decoder_cell(
            tiled_agenda,
            tiled_extended_base_words,
            tiled_oov,
            tiled_base_sent,
            tiled_base_lengths,
            vocab_size,
            attn_dim,
            hidden_dim,
            num_layer,
            enable_dropout=enable_dropout,
            dropout_keep=dropout_keep,
            no_insert_delete_attn=no_insert_delete_attn,
            beam_width=beam_width)

        decoder = seq2seq.BeamSearchDecoder(
            cell,
            create_embedding_fn(vocab_size),
            tf.fill([true_batch_size], start_token_id),
            stop_token_id,
            zero_states,
            beam_width=beam_width,
            length_penalty_weight=0.0)

        return seq2seq.dynamic_decode(decoder,
                                      maximum_iterations=max_sentence_length,
                                      swap_memory=swap_memory)
def _build_decoder_beam_search(self):
    batch_size, _ = tf.unstack(tf.shape(self._labels))

    attention_mechanisms, layer_sizes = self._create_attention_mechanisms(
        beam_search=True)
    decoder_initial_state_tiled = seq2seq.tile_batch(
        self._decoder_initial_state, multiplier=self._hparams.beam_width)

    if self._hparams.enable_attention is True:
        attention_cells = seq2seq.AttentionWrapper(
            cell=self._decoder_cells,
            attention_mechanism=attention_mechanisms,
            attention_layer_size=layer_sizes,
            initial_cell_state=decoder_initial_state_tiled,
            alignment_history=self._hparams.write_attention_alignment,
            output_attention=self._output_attention)
        initial_state = attention_cells.zero_state(
            dtype=self._hparams.dtype,
            batch_size=batch_size * self._hparams.beam_width)
        initial_state = initial_state.clone(
            cell_state=decoder_initial_state_tiled)
        cells = attention_cells
    else:
        cells = self._decoder_cells
        initial_state = decoder_initial_state_tiled

    self._decoder_inference = seq2seq.BeamSearchDecoder(
        cell=cells,
        embedding=self._embedding_matrix,
        start_tokens=array_ops.fill([batch_size], self._GO_ID),
        end_token=self._EOS_ID,
        initial_state=initial_state,
        beam_width=self._hparams.beam_width,
        output_layer=self._dense_layer,
        length_penalty_weight=0.5,
    )

    outputs, states, lengths = seq2seq.dynamic_decode(
        self._decoder_inference,
        impute_finished=False,
        maximum_iterations=self._hparams.max_label_length,
        swap_memory=False)

    if self._hparams.write_attention_alignment is True:
        self.attention_summary = self._create_attention_alignments_summary(
            states)

    self.inference_outputs = outputs.beam_search_decoder_output
    self.inference_predicted_ids = outputs.predicted_ids[:, :, 0]  # first (best) beam
    self.inference_predicted_beam = outputs.predicted_ids
    self.beam_search_output = outputs.beam_search_decoder_output
def _build_decoder_test_beam_search(self):
    r"""Builds a beam search test decoder."""
    if self._hparams.enable_attention is True:
        cells, initial_state = add_attention(
            cells=self._decoder_cells,
            attention_types=self._hparams.attention_type[1],
            num_units=self._hparams.decoder_units_per_layer[-1],
            memory=self._encoder_memory,
            memory_len=self._encoder_features_len,
            beam_search=True,
            batch_size=self._batch_size,
            beam_width=self._hparams.beam_width,
            initial_state=self._decoder_initial_state,
            mode=self._mode,
            dtype=self._hparams.dtype,
            fusion_type='linear_fusion',
            write_attention_alignment=self._hparams.write_attention_alignment)
    else:
        # The non-attentive beam decoder still needs tile_batch: the
        # BeamSearchDecoder expects an initial state of size
        # batch_size * beam_width.
        cells = self._decoder_cells
        decoder_initial_state_tiled = seq2seq.tile_batch(
            self._decoder_initial_state,
            multiplier=self._hparams.beam_width)
        initial_state = decoder_initial_state_tiled

    self._decoder_inference = seq2seq.BeamSearchDecoder(
        cell=cells,
        embedding=self._embedding_matrix,
        start_tokens=array_ops.fill([self._batch_size], self._GO_ID),
        end_token=self._EOS_ID,
        initial_state=initial_state,
        beam_width=self._hparams.beam_width,
        output_layer=self._dense_layer,
        length_penalty_weight=0.6,
    )

    outputs, states, lengths = seq2seq.dynamic_decode(
        self._decoder_inference,
        impute_finished=False,
        maximum_iterations=self._hparams.max_label_length,
        swap_memory=False)

    if self._hparams.write_attention_alignment is True:
        self.attention_summary, self.attention_alignment = \
            self._create_attention_alignments_summary(states)

    self.inference_outputs = outputs.beam_search_decoder_output
    self.inference_predicted_ids = outputs.predicted_ids[:, :, 0]  # first (best) beam
    self.inference_predicted_beam = outputs.predicted_ids
    self.beam_search_output = outputs.beam_search_decoder_output
def _build_decoder(self):
    with tf.variable_scope("dialog_decoder"):
        with tf.variable_scope("decoder_output_projection"):
            output_layer = layers_core.Dense(self.config.vocab_size,
                                             use_bias=False,
                                             name="output_projection")
        with tf.variable_scope("decoder_rnn"):
            dec_cell, dec_init_state = self._build_decoder_cell(
                enc_outputs=self.encoder_outputs,
                enc_state=self.encoder_state)

            # Training or eval: not infer, so decode turn by turn with the
            # ground-truth inputs.
            if self.mode != ModelMode.infer:
                resp_emb_inp = tf.nn.embedding_lookup(
                    self.decoder_embeddings, self.target_input)
                helper = tc_seq2seq.TrainingHelper(resp_emb_inp,
                                                   self.target_length)
                decoder = tc_seq2seq.BasicDecoder(
                    cell=dec_cell,
                    helper=helper,
                    initial_state=dec_init_state,
                    output_layer=output_layer)
                dec_outputs, dec_state, _ = tc_seq2seq.dynamic_decode(decoder)
                sample_id = dec_outputs.sample_id
                logits = dec_outputs.rnn_output
            else:
                beam_width = self.config.beam_size
                length_penalty_weight = self.config.length_penalty_weight
                maximum_iterations = tf.to_int32(self.config.infer_max_len)
                start_tokens = tf.fill([self.batch_size],
                                       self.config.sos_idx)
                end_token = self.config.eos_idx

                # Beam search decoder.
                decoder = tc_seq2seq.BeamSearchDecoder(
                    cell=dec_cell,
                    embedding=self.decoder_embeddings,
                    start_tokens=start_tokens,
                    end_token=end_token,
                    initial_state=dec_init_state,
                    beam_width=beam_width,
                    output_layer=output_layer,
                    length_penalty_weight=length_penalty_weight)
                dec_outputs, dec_state, _ = tc_seq2seq.dynamic_decode(
                    decoder,
                    maximum_iterations=maximum_iterations,
                )
                # Beam search emits no per-step logits.
                logits = tf.no_op()
                sample_id = dec_outputs.predicted_ids

    self.logits = logits
    self.sample_id = sample_id
def beam_eval_decoder(agenda, embeddings, start_token_id, stop_token_id,
                      base_sent_hiddens, insert_word_embeds,
                      delete_word_embeds, base_length, iw_length, dw_length,
                      attn_dim, hidden_dim, num_layer, maximum_iterations,
                      beam_width, swap_memory, enable_dropout=False,
                      dropout_keep=1., no_insert_delete_attn=False):
    with tf.variable_scope(OPS_NAME, 'decoder', reuse=True):
        true_batch_size = tf.shape(base_sent_hiddens)[0]

        tiled_agenda = seq2seq.tile_batch(agenda, beam_width)
        tiled_base_sent = seq2seq.tile_batch(base_sent_hiddens, beam_width)
        tiled_insert_embeds = seq2seq.tile_batch(insert_word_embeds,
                                                 beam_width)
        tiled_delete_embeds = seq2seq.tile_batch(delete_word_embeds,
                                                 beam_width)
        tiled_src_lengths = seq2seq.tile_batch(base_length, beam_width)
        tiled_iw_lengths = seq2seq.tile_batch(iw_length, beam_width)
        tiled_dw_lengths = seq2seq.tile_batch(dw_length, beam_width)

        start_token_id = tf.cast(start_token_id, tf.int32)
        stop_token_id = tf.cast(stop_token_id, tf.int32)

        cell = create_decoder_cell(
            tiled_agenda,
            tiled_base_sent, tiled_insert_embeds, tiled_delete_embeds,
            tiled_src_lengths, tiled_iw_lengths, tiled_dw_lengths,
            attn_dim, hidden_dim, num_layer,
            enable_dropout=enable_dropout,
            dropout_keep=dropout_keep,
            no_insert_delete_attn=no_insert_delete_attn)

        output_layer = DecoderOutputLayer(embeddings, beam_decoder=True)
        zero_states = create_trainable_zero_state(cell, true_batch_size,
                                                  beam_width)

        decoder = seq2seq.BeamSearchDecoder(
            cell,
            embeddings,
            tf.fill([true_batch_size], start_token_id),
            stop_token_id,
            zero_states,
            beam_width=beam_width,
            output_layer=output_layer,
            length_penalty_weight=0.0)

        return seq2seq.dynamic_decode(decoder,
                                      maximum_iterations=maximum_iterations,
                                      swap_memory=swap_memory)
def get_beam_ids(self, cell, projection_layer):
    initial_state = cell.zero_state(
        self.batch_size * self.config.BEAM_WIDTH, dtype=tf.float32)
    if self.config.LEN_EMB_SIZE > 0:
        output_seq_len = seq2seq.tile_batch(
            self.output_len, multiplier=self.config.BEAM_WIDTH)
        cell = LenControlWrapper(cell,
                                 output_seq_len,
                                 self.len_embeddings,
                                 initial_cell_state=initial_state)
        initial_state = cell.zero_state(
            self.batch_size * self.config.BEAM_WIDTH, dtype=tf.float32)
    latent_variables = seq2seq.tile_batch(
        self.latent_variables, multiplier=self.config.BEAM_WIDTH)
    cell = AlignmentWrapper(cell,
                            latent_variables,
                            initial_cell_state=initial_state)
    initial_state = cell.zero_state(
        self.batch_size * self.config.BEAM_WIDTH, dtype=tf.float32)
    if not self.is_training:
        decoder = seq2seq.BeamSearchDecoder(
            cell,
            self.embedding,
            self.go_input(),
            self.eos_idx,
            initial_state=initial_state,
            beam_width=self.config.BEAM_WIDTH,
            output_layer=projection_layer)
        outputs, _, seq_len = seq2seq.dynamic_decode(
            decoder, maximum_iterations=tf.reduce_max(self.output_len))
        return outputs.predicted_ids[:, :, 0]
def _build_decoder_action(model, dialogue_state, hparams, start_token,
                          end_token, output_layer):
    """Build the decoder for action states."""
    iterator = model.iterator
    start_token_id = tf.cast(
        model.vocab_table.lookup(tf.constant(start_token)), tf.int32)
    end_token_id = tf.cast(
        model.vocab_table.lookup(tf.constant(end_token)), tf.int32)
    start_tokens = tf.fill([model.batch_size], start_token_id)
    end_token = end_token_id
    # kb is not used again

    ## Decoder.
    with tf.variable_scope("action_decoder") as decoder_scope:
        # We initialize the cell with the last layer of the last hidden state.
        cell, decoder_initial_state = _build_action_decoder_cell(
            model, hparams, dialogue_state, model.global_gpu_num)
        model.global_gpu_num += 1

        ## Train or eval (also covers mutable train).
        # decoder_emb_inp: [max_time, batch_size, num_units]
        action = iterator.action
        # Shift the action sequence right by one step.
        paddings = tf.constant([[0, 0], [1, 0]])
        action = tf.pad(action, paddings, "CONSTANT",
                        constant_values=0)[:, :-1]
        decoder_emb_inp = tf.nn.embedding_lookup(model.embedding_decoder,
                                                 action)
        # Helper
        helper_train = seq2seq.TrainingHelper(
            decoder_emb_inp, iterator.action_len, time_major=False)
        # Decoder
        my_decoder_train = seq2seq.BasicDecoder(cell, helper_train,
                                                decoder_initial_state,
                                                output_layer)
        # Dynamic decoding
        outputs_train, _, _ = seq2seq.dynamic_decode(
            my_decoder_train,
            output_time_major=False,
            swap_memory=True,
            scope=decoder_scope)
        sample_id_train = outputs_train.sample_id
        logits_train = outputs_train.rnn_output

        ## Inference.
        beam_width = hparams.beam_width
        length_penalty_weight = hparams.length_penalty_weight
        if model.mode == tf.estimator.ModeKeys.PREDICT and beam_width > 0:
            my_decoder_infer = seq2seq.BeamSearchDecoder(
                cell=cell,
                embedding=model.embedding_decoder,
                start_tokens=start_tokens,
                end_token=end_token,
                initial_state=decoder_initial_state,
                beam_width=beam_width,
                output_layer=output_layer,
                length_penalty_weight=length_penalty_weight)
        else:
            # Helper
            if model.mode in dialogue_utils.self_play_modes:
                helper_infer = seq2seq.SampleEmbeddingHelper(
                    model.embedding_decoder, start_tokens, end_token)
            else:
                helper_infer = seq2seq.GreedyEmbeddingHelper(
                    model.embedding_decoder, start_tokens, end_token)
            # Decoder
            my_decoder_infer = seq2seq.BasicDecoder(
                cell,
                helper_infer,
                decoder_initial_state,
                output_layer=output_layer  # applied per timestep
            )
        # Dynamic decoding
        outputs_infer, _, _ = seq2seq.dynamic_decode(
            my_decoder_infer,
            maximum_iterations=hparams.len_action,
            output_time_major=False,
            swap_memory=True,
            scope=decoder_scope)
        if model.mode == tf.estimator.ModeKeys.PREDICT and beam_width > 0:
            logits_infer = tf.no_op()
            sample_id_infer = outputs_infer.predicted_ids
        else:
            logits_infer = outputs_infer.rnn_output
            sample_id_infer = outputs_infer.sample_id
    return logits_train, logits_infer, sample_id_train, sample_id_infer
def create_model_predict(self, input, mode='decode'):
    use_beam_search = False
    if self.params.beam_with > 1:
        use_beam_search = True
    with tf.variable_scope("attetnion_seq2seq", reuse=tf.AUTO_REUSE):
        embeddings_matrix = self._create_embedding()
        keep_prob = 1 - self.params.dropout_rate
        batch_size = tf.shape(input)[0]

        # encoder
        encoder_outputs, encoder_last_states, encoder_inputs_length = \
            self._create_encoder(embeddings_matrix, input, keep_prob)

        # decoder
        with tf.variable_scope('decoder'):
            # Output projection layer to convert cell outputs to logits
            output_layer = Dense(self.params.vocab_size,
                                 name='output_project')
            input_layer = Dense(self.params.hidden_units * 2,
                                dtype=tf.float32,
                                name='input_projection')
            decoder_cell, decoder_initial_state = create_decoder_cell(
                enc_outputs=encoder_outputs,
                enc_states=encoder_last_states,
                enc_seq_len=encoder_inputs_length,
                num_layers=self.params.depth,
                num_units=self.params.hidden_units * 2,
                keep_prob=keep_prob,
                use_residual=self.params.use_residual,
                use_beam_search=use_beam_search,
                beam_size=self.params.beam_with,
                batch_size=batch_size,
                top_attention=self.params.top_attention)

            # start_tokens: [batch_size,] `int32` vector
            start_tokens = tf.ones([batch_size], tf.int32) * data_utils.GO_ID
            end_token = data_utils.EOS_ID

            def embed_and_input_proj(inputs):
                return input_layer(
                    tf.nn.embedding_lookup(embeddings_matrix, inputs))

            if self.params.beam_with <= 1:
                decode_helper = seq2seq.GreedyEmbeddingHelper(
                    start_tokens=start_tokens,
                    end_token=end_token,
                    embedding=embed_and_input_proj)
                inference_decoder = seq2seq.BasicDecoder(
                    cell=decoder_cell,
                    helper=decode_helper,
                    initial_state=decoder_initial_state,
                    output_layer=output_layer)
                decoder_output, _, _ = seq2seq.dynamic_decode(
                    decoder=inference_decoder,
                    output_time_major=False,
                    impute_finished=True,
                    maximum_iterations=self.params.max_seq_length)
            else:
                inference_decoder = seq2seq.BeamSearchDecoder(
                    cell=decoder_cell,
                    embedding=embed_and_input_proj,
                    start_tokens=start_tokens,
                    end_token=end_token,
                    initial_state=decoder_initial_state,
                    beam_width=self.params.beam_with,
                    output_layer=output_layer)
                decoder_output, _, _ = seq2seq.dynamic_decode(
                    decoder=inference_decoder,
                    output_time_major=False,
                    maximum_iterations=self.params.max_seq_length)

        if self.params.beam_with <= 1:
            decoder_predict = tf.expand_dims(decoder_output.sample_id, -1)
        else:
            decoder_predict = decoder_output.predicted_ids
        decoder_predict = tf.identity(decoder_predict, 'predicts')
    return decoder_predict
def build_decoder(self):
    with tf.variable_scope("decoder"):
        decoder_cell, decoder_initial_state = self.build_decoder_cell()
        # start tokens: [batch_size], fed to BeamSearchDecoder during inference
        start_tokens = tf.ones([self.batch_size],
                               dtype=tf.int32) * data_util.ID_GO
        end_token = data_util.ID_EOS
        input_layer = Dense(self.state_size * 2,
                            dtype=tf.float32,
                            name="input_layer")
        output_layer = Dense(self.decoder_vocab_size,
                             name="output_projection")

        if self.mode == "train":
            # Feed the ground-truth decoder input token at every time step.
            decoder_input_lookup = tf.nn.embedding_lookup(
                self.embedding_matrix, self.decoder_input)
            decoder_input_lookup = input_layer(decoder_input_lookup)
            training_helper = seq2seq.TrainingHelper(
                inputs=decoder_input_lookup,
                sequence_length=self.decoder_train_len,
                name="training_helper")
            training_decoder = seq2seq.BasicDecoder(
                cell=decoder_cell,
                initial_state=decoder_initial_state,
                helper=training_helper,
                output_layer=output_layer)
            # decoder_outputs_train: BasicDecoderOutput
            #   namedtuple(rnn_output, sample_id)
            # decoder_outputs_train.rnn_output:
            #   [batch_size, max_time_step + 1, num_decoder_symbols] if output_time_major=False
            #   [max_time_step + 1, batch_size, num_decoder_symbols] if output_time_major=True
            # decoder_outputs_train.sample_id: [batch_size], tf.int32
            max_decoder_len = tf.reduce_max(self.decoder_train_len)
            decoder_outputs_train, final_state, _ = seq2seq.dynamic_decode(
                training_decoder,
                impute_finished=True,
                swap_memory=True,
                maximum_iterations=max_decoder_len)
            self.decoder_logits_train = tf.identity(
                decoder_outputs_train.rnn_output)
            decoder_pred = tf.argmax(self.decoder_logits_train, axis=2)

            # Sequence mask marking valid steps, excluding zero padding.
            weights = tf.sequence_mask(self.decoder_len,
                                       maxlen=max_decoder_len,
                                       dtype=tf.float32)
            # Cross-entropy loss over the whole sequence prediction,
            # ignoring loss contributed by zero padding.
            self.loss = seq2seq.sequence_loss(
                logits=self.decoder_logits_train,
                targets=self.decoder_target,
                weights=weights,
                average_across_batch=True,
                average_across_timesteps=True)
            tf.summary.scalar("loss", self.loss)

            # NOTE: `with a and b:` would enter only `b`; both context
            # managers must be listed for the scope and device to apply.
            with tf.variable_scope("train_optimizer"), \
                    tf.device("/device:GPU:1"):
                # Adagrad with an exponentially decayed learning rate;
                # clip gradients by max_norm 5.0 and count iterations
                # with the global step.
                params = tf.trainable_variables()
                gradients = tf.gradients(self.loss, params)
                clipped_gradients, _ = tf.clip_by_global_norm(gradients, 5.0)
                learning_rate = tf.train.exponential_decay(
                    self.lr, self.global_step, 10000, 0.96)
                opt = tf.train.AdagradOptimizer(learning_rate)
                self.train_op = opt.apply_gradients(
                    zip(clipped_gradients, params),
                    global_step=self.global_step)
        elif self.mode == "test":
            def embedding_proj(inputs):
                return input_layer(
                    tf.nn.embedding_lookup(self.embedding_matrix, inputs))

            inference_decoder = seq2seq.BeamSearchDecoder(
                cell=decoder_cell,
                embedding=embedding_proj,
                start_tokens=start_tokens,
                end_token=end_token,
                initial_state=decoder_initial_state,
                beam_width=self.beam_depth,
                output_layer=output_layer)
            # For BeamSearchDecoder, dynamic_decode returns
            # decoder_outputs: FinalBeamSearchDecoderOutput
            #   namedtuple(predicted_ids, beam_search_decoder_output)
            # decoder_outputs.predicted_ids:
            #   [batch_size, max_time_step, beam_width] if output_time_major=False
            #   [max_time_step, batch_size, beam_width] if output_time_major=True
            # decoder_outputs.beam_search_decoder_output:
            #   BeamSearchDecoderOutput namedtuple(scores, predicted_ids, parent_ids)
            with tf.device("/device:GPU:1"):
                decoder_outputs, decoder_last_state, decoder_output_length = \
                    seq2seq.dynamic_decode(decoder=inference_decoder,
                                           output_time_major=False,
                                           swap_memory=True,
                                           maximum_iterations=self.max_iter)
            self.decoder_pred_test = decoder_outputs.predicted_ids
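# The shape comments above describe FinalBeamSearchDecoderOutput. A small
# post-processing sketch (hypothetical helper, plain numpy, not from the
# original code) that keeps the top beam and truncates each hypothesis at
# the first end token:
import numpy as np

def trim_at_eos(predicted_ids, eos_id):
    """predicted_ids: [batch, time, beam] array from a beam search decoder."""
    best = predicted_ids[:, :, 0]  # top beam per example
    trimmed = []
    for row in best:
        eos_positions = np.where(row == eos_id)[0]
        trimmed.append(row[:eos_positions[0]] if len(eos_positions) else row)
    return trimmed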
attention_mechanism = seq2seq.BahdanauAttention(
    num_units=hidden_dim * 2,
    memory=encoder_outputs,
    memory_sequence_length=encoder_inputs_length)
# decoder_cell = tf.contrib.rnn.MultiRNNCell(
#     [tf.nn.rnn_cell.BasicLSTMCell(hidden_dim * 2) for _ in range(num_layers)])
decoder_cell = seq2seq.AttentionWrapper(
    cell=global_decoder_cell,
    attention_mechanism=attention_mechanism,
    attention_layer_size=hidden_dim * 2)
inference_decoder = seq2seq.BeamSearchDecoder(
    cell=decoder_cell,
    embedding=no_op_embedding,
    start_tokens=tf.fill([batch_size], 12),
    end_token=0,
    initial_state=decoder_cell.zero_state(
        batch_size * beam_width,
        tf.float32).clone(cell_state=encoder_last_state),
    beam_width=beam_width,
    # initial_state = decoder_cell_inf.zero_state(batch_size=batch_size, dtype=tf.float32)
    output_layer=projection_layer)
print(inference_decoder)

with tf.variable_scope('decode_with_shared_attention', reuse=True):
    inference_decoder_output, _, _ = seq2seq.dynamic_decode(
        decoder=inference_decoder,
        impute_finished=False,
        maximum_iterations=tf.reduce_max(encoder_inputs_length))

for var in tf.trainable_variables():
    print(var)
def _build_sentence_decoder(self, inputs, context_encoder_outputs,
                            sentence_encoder_final_states,
                            sentence_encoder_outputs):
    batch_size = self._batch_size
    num_sentence = self._num_sentence

    word_embedding = model_helper.create_word_embedding(
        num_vocab=self.hparams.num_vocab,
        embedding_dim=self.hparams.word_embedding_dim,
        name='decoder/word_embedding',
        pretrained_word_matrix=self.hparams.pretrained_word_path)

    # tile_batch in inference mode
    beam_width = self.hparams.beam_width
    if self.mode == tf.contrib.learn.ModeKeys.INFER:
        # Only decode the last timestep.
        if 'lstm' in self.hparams.rnn_cell_type.lower():
            batched_sentence_encoder_states = []
            for encoder_state in sentence_encoder_final_states:
                target_shape = tf.stack([batch_size, num_sentence, -1])
                c = s2s.tile_batch(
                    tf.reshape(encoder_state.c, target_shape)[:, -1, :],
                    beam_width)
                h = s2s.tile_batch(
                    tf.reshape(encoder_state.h, target_shape)[:, -1, :],
                    beam_width)
                batched_sentence_encoder_states.append(
                    tf.contrib.rnn.LSTMStateTuple(c=c, h=h))
        else:
            batched_sentence_encoder_states = [
                s2s.tile_batch(
                    tf.reshape(
                        encoder_state,
                        tf.stack([batch_size, num_sentence, -1]))[:, -1, :],
                    beam_width)
                for encoder_state in sentence_encoder_final_states
            ]
        sentence_encoder_final_states = tuple(
            batched_sentence_encoder_states)

        sentence_encoder_outputs = s2s.tile_batch(
            tf.reshape(
                sentence_encoder_outputs,
                tf.stack([batch_size, num_sentence, -1,
                          self.hparams.num_rnn_units]))[:, -1, :, :],
            beam_width)
        source_lengths = s2s.tile_batch(inputs.src_lengths[:, -1],
                                        beam_width)
        context_encoder_outputs = tf.reshape(
            context_encoder_outputs,
            tf.stack([batch_size, num_sentence,
                      self.hparams.num_rnn_units]))[:, -1, :]
        context_encoder_outputs = tf.tile(
            tf.expand_dims(context_encoder_outputs, axis=1),
            [1, beam_width, 1])
        effective_batch_size = self._batch_size * beam_width
    else:
        source_lengths = tf.reshape(inputs.src_lengths, [-1])
        context_encoder_outputs.set_shape(
            [None, self.hparams.num_rnn_units])
        effective_batch_size = self._batch_size * self._num_sentence

    # Current strategy: no residual layers at the decoder.
    attention_mechanism = model_helper.create_attention_mechanism(
        attention_option=self.hparams.attention_type,
        num_units=self.hparams.num_rnn_units,
        memory=sentence_encoder_outputs,
        source_length=source_lengths)
    decoder_cell = s2s.AttentionWrapper(
        model_helper.create_rnn_cell(
            cell_type=self.hparams.rnn_cell_type,
            num_layers=self.hparams.num_rnn_layers,
            num_units=self.hparams.num_rnn_units,
            dropout_keep_prob=self._dropout_keep_prob,
            num_residual_layers=0),
        attention_mechanism,
        attention_layer_size=self.hparams.num_rnn_units,
        alignment_history=False,
        name="attention")
    decoder_initial_state = decoder_cell.zero_state(effective_batch_size,
                                                    tf.float32)
    decoder_initial_state = decoder_initial_state.clone(
        cell_state=sentence_encoder_final_states)

    with tf.variable_scope('output_projection'):
        output_layer = layers_core.Dense(self.hparams.num_vocab,
                                         name="output_projection")
        self.output_layer = output_layer

    if self.mode in {tf.contrib.learn.ModeKeys.TRAIN,
                     tf.contrib.learn.ModeKeys.EVAL}:
        decoder_input_tokens = tf.reshape(
            inputs.targets_in, tf.stack([batch_size * num_sentence, -1]))
        decoder_inputs = tf.nn.embedding_lookup(word_embedding,
                                                decoder_input_tokens)
        target_lengths = tf.reshape(inputs.tgt_lengths, [-1])

        # NOTE: the `and False` deliberately disables the scheduled
        # sampling branch; remove it to re-enable.
        if self.mode == tf.contrib.learn.ModeKeys.TRAIN and False:
            sampling_probability = 1.0 - tf.train.exponential_decay(
                1.0,
                self.global_step,
                self.hparams.scheduled_sampling_decay_steps,
                self.hparams.scheduled_sampling_decay_rate,
                staircase=True,
                name='scheduled_sampling_prob')
            helper = s2s.ScheduledEmbeddingTrainingHelper(
                inputs=decoder_inputs,
                sequence_length=target_lengths,
                embedding=word_embedding,
                sampling_probability=sampling_probability,
                name='scheduled_sampling_helper')
        else:
            helper = s2s.TrainingHelper(
                inputs=decoder_inputs,
                sequence_length=target_lengths,
                name='training_helper',
            )
        decoder = s2s.BasicDecoder(decoder_cell,
                                   helper,
                                   decoder_initial_state,
                                   output_layer=None)
        final_outputs, final_state, _ = dynamic_decode_with_concat(
            decoder, context_encoder_outputs, swap_memory=True)
        logits = final_outputs.rnn_output
        sample_id = final_outputs.sample_id
    else:
        sos_id = tf.cast(self.vocab_table.lookup(tf.constant(dataset.SOS)),
                         tf.int32)
        eos_id = tf.cast(self.vocab_table.lookup(tf.constant(dataset.EOS)),
                         tf.int32)
        sos_ids = tf.fill([batch_size], sos_id)
        decoder = s2s.BeamSearchDecoder(
            cell=decoder_cell,
            embedding=word_embedding,
            start_tokens=sos_ids,
            end_token=eos_id,
            initial_state=decoder_initial_state,
            beam_width=beam_width,
            output_layer=self.output_layer)
        final_outputs, final_state, _ = dynamic_decode_with_concat(
            decoder,
            context_encoder_outputs,
            maximum_iterations=self.hparams.target_max_length,
            swap_memory=True)
        logits = final_outputs.beam_search_decoder_output.scores
        sample_id = final_outputs.predicted_ids
    return logits, final_state, sample_id
def decode(self, encoder_outputs, encoder_state, source_sequence_length):
    with tf.variable_scope("Decoder") as scope:
        beam_width = self.beam_width
        decoder_type = self.decoder_type
        seq_max_len = self.seq_max_len
        batch_size = tf.shape(encoder_outputs)[0]

        if self.path_embed_method == "lstm":
            self.decoder_cell = self._build_decode_cell()
            if self.mode == "test" and beam_width > 0:
                memory = seq2seq.tile_batch(self.encoder_outputs,
                                            multiplier=beam_width)
                source_sequence_length = seq2seq.tile_batch(
                    self.source_sequence_length, multiplier=beam_width)
                encoder_state = seq2seq.tile_batch(self.encoder_state,
                                                   multiplier=beam_width)
                batch_size = self.batch_size * beam_width
            else:
                memory = encoder_outputs
                source_sequence_length = source_sequence_length
                encoder_state = encoder_state

            attention_mechanism = seq2seq.BahdanauAttention(
                self.hidden_layer_dim,
                memory,
                memory_sequence_length=source_sequence_length)
            self.decoder_cell = seq2seq.AttentionWrapper(
                self.decoder_cell,
                attention_mechanism,
                attention_layer_size=self.hidden_layer_dim)
            self.decoder_initial_state = self.decoder_cell.zero_state(
                batch_size, tf.float32).clone(cell_state=encoder_state)

        projection_layer = Dense(self.word_vocab_size, use_bias=False)

        """For training the model"""
        if self.mode == "train":
            decoder_train_helper = tf.contrib.seq2seq.TrainingHelper(
                self.decoder_train_inputs_embedded,
                self.decoder_train_length)
            decoder_train = seq2seq.BasicDecoder(self.decoder_cell,
                                                 decoder_train_helper,
                                                 self.decoder_initial_state,
                                                 projection_layer)
            decoder_outputs_train, decoder_states_train, decoder_seq_len_train = \
                seq2seq.dynamic_decode(decoder_train)
            decoder_logits_train = decoder_outputs_train.rnn_output
            self.decoder_logits_train = tf.reshape(
                decoder_logits_train, [batch_size, -1, self.word_vocab_size])

        """For testing the model"""
        # if self.mode == "infer" or self.if_pred_on_dev:
        if decoder_type == "greedy":
            decoder_infer_helper = seq2seq.GreedyEmbeddingHelper(
                self.word_embeddings,
                tf.ones([batch_size], dtype=tf.int32),
                self.EOS)
            decoder_infer = seq2seq.BasicDecoder(self.decoder_cell,
                                                 decoder_infer_helper,
                                                 self.decoder_initial_state,
                                                 projection_layer)
        elif decoder_type == "beam":
            decoder_infer = seq2seq.BeamSearchDecoder(
                cell=self.decoder_cell,
                embedding=self.word_embeddings,
                start_tokens=tf.ones([batch_size], dtype=tf.int32),
                end_token=self.EOS,
                initial_state=self.decoder_initial_state,
                beam_width=beam_width,
                output_layer=projection_layer)

        decoder_outputs_infer, decoder_states_infer, decoder_seq_len_infer = \
            seq2seq.dynamic_decode(decoder_infer,
                                   maximum_iterations=seq_max_len)

        if decoder_type == "beam":
            self.decoder_logits_infer = tf.no_op()
            self.sample_id = decoder_outputs_infer.predicted_ids
        elif decoder_type == "greedy":
            self.decoder_logits_infer = decoder_outputs_infer.rnn_output
            self.sample_id = decoder_outputs_infer.sample_id
def __init__(self, vocab_size, hidden_size, dropout, num_layers,
             max_gradient_norm, batch_size, learning_rate, lr_decay_factor,
             max_target_length, max_source_length, decoder_mode=False):
    '''
    vocab_size: number of vocab tokens
    hidden_size: dimension of hidden layers
    num_layers: number of hidden layers
    max_gradient_norm: maximum gradient magnitude
    batch_size: number of training examples fed to network at once
    learning_rate: starting learning rate of network
    lr_decay_factor: amount by which to decay learning rate
    decoder_mode: whether to run in inference (beam search) mode
    '''
    GO_ID = config.GO_ID
    EOS_ID = config.EOS_ID
    self.max_source_length = max_source_length
    self.max_target_length = max_target_length
    self.vocab_size = vocab_size
    self.batch_size = batch_size
    self.global_step = tf.Variable(0, trainable=False)
    self.learning_rate = learning_rate

    self.encoder_inputs = tf.placeholder(shape=(None, None),
                                         dtype=tf.int32,
                                         name='encoder_inputs')
    self.source_lengths = tf.placeholder(shape=(None,),
                                         dtype=tf.int32,
                                         name='source_lengths')
    self.decoder_targets = tf.placeholder(shape=(None, None),
                                          dtype=tf.int32,
                                          name='decoder_targets')
    self.target_lengths = tf.placeholder(shape=(None,),
                                         dtype=tf.int32,
                                         name="target_lengths")

    with tf.variable_scope('embeddings') as scope:
        embeddings = tf.Variable(tf.random_uniform(
            [vocab_size, hidden_size], -1.0, 1.0), dtype=tf.float32)
        encoder_inputs_embedded = tf.nn.embedding_lookup(
            embeddings, self.encoder_inputs)
        targets_embedding = tf.nn.embedding_lookup(embeddings,
                                                   self.decoder_targets)

    with tf.variable_scope('encoder') as scope:
        encoder_cell = rnn.LSTMCell(hidden_size)
        encoder_cell = rnn.DropoutWrapper(encoder_cell,
                                          input_keep_prob=dropout)
        encoder_cell = rnn.MultiRNNCell([encoder_cell] * num_layers)
        encoder_outputs, encoder_state = tf.nn.dynamic_rnn(
            cell=encoder_cell,
            inputs=encoder_inputs_embedded,
            sequence_length=self.source_lengths,
            dtype=tf.float32,
            time_major=False)

    with tf.variable_scope('decoder') as scope:
        decoder_cell = rnn.LSTMCell(hidden_size)
        decoder_cell = rnn.DropoutWrapper(decoder_cell,
                                          input_keep_prob=dropout)
        decoder_cell = rnn.MultiRNNCell([decoder_cell] * num_layers,
                                        state_is_tuple=True)
        if decoder_mode:
            beam_width = 2
            # Fixes over the original: the decoder cell was missing, the
            # start-token constant was misspelled (GOD_ID), and the encoder
            # state must be tiled to batch_size * beam_width before it is
            # handed to the beam search decoder.
            decoder = seq2seq.BeamSearchDecoder(
                cell=decoder_cell,
                embedding=embeddings,
                start_tokens=tf.tile([GO_ID], [batch_size]),
                end_token=EOS_ID,
                initial_state=seq2seq.tile_batch(encoder_state,
                                                 multiplier=beam_width),
                beam_width=beam_width)
            final_outputs, final_state, final_sequence_lengths = \
                seq2seq.dynamic_decode(
                    decoder=decoder,
                    maximum_iterations=self.max_target_length)
            self.logits = final_outputs.predicted_ids
        else:
            helper = seq2seq.TrainingHelper(targets_embedding,
                                            self.target_lengths)
            decoder = seq2seq.BasicDecoder(decoder_cell, helper,
                                           encoder_state, Dense(vocab_size))
            final_outputs, final_state, final_sequence_lengths = \
                seq2seq.dynamic_decode(decoder=decoder)
            self.logits = final_outputs.rnn_output

    if not decoder_mode:
        with tf.variable_scope("loss") as scope:
            # Pad the logits: dynamic decode produces results whose time
            # dimension is not always consistent with the targets.
            pad_size = self.max_target_length - tf.reduce_max(
                final_sequence_lengths)
            self.logits = tf.pad(self.logits,
                                 [[0, 0], [0, pad_size], [0, 0]])
            weights = tf.sequence_mask(lengths=final_sequence_lengths,
                                       maxlen=self.max_target_length,
                                       dtype=tf.float32,
                                       name='weights')
            # Cross-entropy loss function.
            x_entropy_loss = seq2seq.sequence_loss(
                logits=self.logits,
                targets=self.decoder_targets,
                weights=weights)
            self.loss = tf.reduce_mean(x_entropy_loss)
        # Adam optimization with per-value gradient clipping.
        optimizer = tf.train.AdamOptimizer()
        gradients = optimizer.compute_gradients(x_entropy_loss)
        capped_grads = [(tf.clip_by_value(grad, -max_gradient_norm,
                                          max_gradient_norm), var)
                        for grad, var in gradients]
        self.train_op = optimizer.apply_gradients(
            capped_grads, global_step=self.global_step)
    self.saver = tf.train.Saver(tf.global_variables())
def __init__(self, vocab_size, embed_size, num_unit, latent_dim, emoji_dim,
             batch_size, kl_ceiling, bow_ceiling, decoder_layer=1,
             start_i=1, end_i=2, beam_width=0, maximum_iterations=50,
             max_gradient_norm=5, lr=1e-3, dropout=0.2, num_gpu=2,
             cell_type=tf.nn.rnn_cell.GRUCell, is_seq2seq=False):
    self.ori_sample = None
    self.rep_sample = None
    self.out_sample = None
    self.sess = None

    self.loss_weight = tf.placeholder_with_default(0., shape=())
    self.policy_weight = tf.placeholder_with_default(1., shape=())
    self.ac_vec = tf.placeholder(tf.float32, shape=[batch_size],
                                 name="accuracy_vector")
    self.ac5_vec = tf.placeholder(tf.float32, shape=[batch_size],
                                  name="top5_accuracy_vector")
    self.is_policy = tf.placeholder_with_default(False, shape=())
    shape = [batch_size, latent_dim]
    self.rdm = tf.placeholder_with_default(
        np.zeros(shape, dtype=np.float32), shape=shape)
    self.q_rdm = tf.placeholder_with_default(
        np.zeros(shape, dtype=np.float32), shape=shape)

    self.end_i = end_i
    self.batch_size = batch_size
    self.num_gpu = num_gpu
    self.num_unit = num_unit
    self.dropout = tf.placeholder_with_default(dropout, (), name="dropout")
    self.beam_width = beam_width
    self.cell_type = cell_type

    self.emoji = tf.placeholder(tf.int32, shape=[batch_size], name="emoji")
    self.ori = tf.placeholder(tf.int32, shape=[None, batch_size],
                              name="original_tweet")  # [len, batch_size]
    self.ori_len = tf.placeholder(tf.int32, shape=[batch_size],
                                  name="original_tweet_length")
    self.rep = tf.placeholder(tf.int32, shape=[None, batch_size],
                              name="response_tweet")
    self.rep_len = tf.placeholder(tf.int32, shape=[batch_size],
                                  name="response_tweet_length")
    self.rep_input = tf.placeholder(tf.int32, shape=[None, batch_size],
                                    name="response_start_tag")
    self.rep_output = tf.placeholder(tf.int32, shape=[None, batch_size],
                                     name="response_end_tag")
    self.reward = tf.placeholder(tf.float32, shape=[batch_size],
                                 name="reward")
    self.kl_weight = tf.placeholder_with_default(1., shape=(),
                                                 name="kl_weight")

    self.placeholders = [
        self.emoji, self.ori, self.ori_len, self.rep, self.rep_len,
        self.rep_input, self.rep_output
    ]

    with tf.variable_scope("embeddings"):
        embedding = Embedding(vocab_size, embed_size)
        ori_emb = embedding(self.ori)  # [max_len, batch_size, embedding_size]
        rep_emb = embedding(self.rep)
        rep_input_emb = embedding(self.rep_input)
        emoji_emb = embedding(self.emoji)  # [batch_size, embedding_size]

    with tf.variable_scope("original_tweet_encoder"):
        ori_encoder_output, ori_encoder_state = build_bidirectional_rnn(
            num_unit, ori_emb, self.ori_len, cell_type, num_gpu,
            self.dropout, base_gpu=0)
        ori_encoder_state_flat = tf.concat(
            [ori_encoder_state[0], ori_encoder_state[1]], axis=1)

    emoji_vec = tf.layers.dense(emoji_emb, emoji_dim, activation=tf.nn.tanh)
    self.emoji_vec = emoji_emb
    # emoji_vec = tf.ones([batch_size, emoji_dim], tf.float32)
    condition_flat = tf.concat([ori_encoder_state_flat, emoji_vec], axis=1)

    with tf.variable_scope("response_tweet_encoder"):
        _, rep_encoder_state = build_bidirectional_rnn(
            num_unit, rep_emb, self.rep_len, cell_type, num_gpu,
            self.dropout, base_gpu=2)
        rep_encoder_state_flat = tf.concat(
            [rep_encoder_state[0], rep_encoder_state[1]], axis=1)

    with tf.variable_scope("representation_network"):
        rn_input = tf.concat([rep_encoder_state_flat, condition_flat],
                             axis=1)
        # simpler representation network
        # r_hidden = rn_input
        r_hidden = tf.layers.dense(rn_input, latent_dim,
                                   activation=tf.nn.relu,
                                   name="r_net_hidden")  # int(1.6 * latent_dim)
        r_hidden_mu = tf.layers.dense(r_hidden, latent_dim,
                                      activation=tf.nn.relu)  # int(1.3 * latent_dim)
        r_hidden_var = tf.layers.dense(r_hidden, latent_dim,
                                       activation=tf.nn.relu)
        self.mu = tf.layers.dense(r_hidden_mu, latent_dim,
                                  activation=tf.nn.tanh, name="q_mean")
        self.log_var = tf.layers.dense(r_hidden_var, latent_dim,
                                       activation=tf.nn.tanh,
                                       name="q_log_var")

    with tf.variable_scope("prior_network"):
        # simpler prior network
        # p_hidden = condition_flat
        p_hidden = tf.layers.dense(condition_flat, int(0.62 * latent_dim),
                                   activation=tf.nn.relu,
                                   name="r_net_hidden")
        p_hidden_mu = tf.layers.dense(p_hidden, int(0.77 * latent_dim),
                                      activation=tf.nn.relu)
        p_hidden_var = tf.layers.dense(p_hidden, int(0.77 * latent_dim),
                                       activation=tf.nn.relu)
        self.p_mu = tf.layers.dense(p_hidden_mu, latent_dim,
                                    activation=tf.nn.tanh, name="p_mean")
        self.p_log_var = tf.layers.dense(p_hidden_var, latent_dim,
                                         activation=tf.nn.tanh,
                                         name="p_log_var")

    with tf.variable_scope("reparameterization"):
        self.normal = tf.cond(
            self.is_policy,
            lambda: self.rdm,
            lambda: tf.random_normal(shape=tf.shape(self.mu)))
        self.z_sample = self.mu + tf.exp(self.log_var / 2.) * self.normal
        self.q_normal = tf.cond(
            self.is_policy,
            lambda: self.q_rdm,
            lambda: tf.random_normal(shape=tf.shape(self.p_mu)))
        self.q_z_sample = self.p_mu + tf.exp(
            self.p_log_var / 2.) * self.q_normal

    if is_seq2seq:
        # Zero out the latent samples to reduce the model to plain seq2seq.
        self.z_sample = self.z_sample - self.z_sample
        self.q_z_sample = self.q_z_sample - self.q_z_sample

    with tf.variable_scope("decoder_train") as decoder_scope:
        if decoder_layer == 2:
            train_decoder_init_state = (
                tf.concat([self.z_sample, ori_encoder_state[0], emoji_vec],
                          axis=1),
                tf.concat([self.z_sample, ori_encoder_state[1], emoji_vec],
                          axis=1))
            dim = latent_dim + num_unit + emoji_dim
            cell = tf.nn.rnn_cell.MultiRNNCell([
                create_rnn_cell(dim, 2, cell_type, num_gpu, self.dropout),
                create_rnn_cell(dim, 3, cell_type, num_gpu, self.dropout)
            ])
        else:
            train_decoder_init_state = tf.concat(
                [self.z_sample, ori_encoder_state_flat, emoji_vec], axis=1)
            dim = latent_dim + 2 * num_unit + emoji_dim
            cell = create_rnn_cell(dim, 2, cell_type, num_gpu, self.dropout)

        with tf.variable_scope("attention"):
            memory = tf.concat(
                [ori_encoder_output[0], ori_encoder_output[1]], axis=2)
            memory = tf.transpose(memory, [1, 0, 2])
            attention_mechanism = seq2seq.LuongAttention(
                dim, memory, memory_sequence_length=self.ori_len,
                scale=True)
            # attention_mechanism = seq2seq.BahdanauAttention(
            #     num_unit, memory, memory_sequence_length=self.ori_len)

        decoder_cell = seq2seq.AttentionWrapper(
            cell, attention_mechanism,
            attention_layer_size=dim)  # TODO: add_name; what atten layer size means
        # decoder_cell = cell

        helper = seq2seq.TrainingHelper(rep_input_emb, self.rep_len + 1,
                                        time_major=True)
        projection_layer = layers_core.Dense(vocab_size, use_bias=False,
                                             name="output_projection")
        decoder = seq2seq.BasicDecoder(
            decoder_cell, helper,
            decoder_cell.zero_state(batch_size, tf.float32).clone(
                cell_state=train_decoder_init_state),
            output_layer=projection_layer)
        train_outputs, _, _ = seq2seq.dynamic_decode(
            decoder,
            output_time_major=True,
            swap_memory=True,
            scope=decoder_scope)
        self.logits = train_outputs.rnn_output

    with tf.variable_scope("decoder_infer") as decoder_scope:
        # normal_sample = tf.random_normal(shape=(batch_size, latent_dim))
        if decoder_layer == 2:
            infer_decoder_init_state = (
                tf.concat([self.q_z_sample, ori_encoder_state[0],
                           emoji_vec], axis=1),
                tf.concat([self.q_z_sample, ori_encoder_state[1],
                           emoji_vec], axis=1))
        else:
            infer_decoder_init_state = tf.concat(
                [self.q_z_sample, ori_encoder_state_flat, emoji_vec],
                axis=1)

        start_tokens = tf.fill([batch_size], start_i)
        end_token = end_i

        if beam_width > 0:
            infer_decoder_init_state = seq2seq.tile_batch(
                infer_decoder_init_state, multiplier=beam_width)
            decoder = seq2seq.BeamSearchDecoder(
                cell=decoder_cell,
                embedding=embedding.coder,
                start_tokens=start_tokens,
                end_token=end_token,
                initial_state=decoder_cell.zero_state(
                    batch_size * beam_width,
                    tf.float32).clone(cell_state=infer_decoder_init_state),
                beam_width=beam_width,
                output_layer=projection_layer,
                length_penalty_weight=0.0)
        else:
            helper = seq2seq.GreedyEmbeddingHelper(embedding.coder,
                                                   start_tokens, end_token)
            decoder = seq2seq.BasicDecoder(
                decoder_cell, helper,
                decoder_cell.zero_state(batch_size, tf.float32).clone(
                    cell_state=infer_decoder_init_state),
                output_layer=projection_layer  # applied per timestep
            )

        # Dynamic decoding
        infer_outputs, _, infer_lengths = seq2seq.dynamic_decode(
            decoder,
            maximum_iterations=maximum_iterations,
            output_time_major=True,
            swap_memory=True,
            scope=decoder_scope)
        if beam_width > 0:
            self.result = infer_outputs.predicted_ids
        else:
            self.result = infer_outputs.sample_id
        self.result_lengths = infer_lengths

    with tf.variable_scope("loss"):
        max_time = tf.shape(self.rep_output)[0]
        with tf.variable_scope("reconstruction"):
            # TODO: use inference decoder's logits to compute recon_loss
            # cross_entropy: [len, batch_size]
            # rep: [len, batch_size]; logits: [len, batch_size, vocab_size]
            cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=self.rep_output, logits=self.logits)
            target_mask = tf.sequence_mask(self.rep_len + 1, max_time,
                                           dtype=self.logits.dtype)
            # time major: [max_len, batch_size]
            target_mask_t = tf.transpose(target_mask)
            self.recon_losses = tf.reduce_sum(cross_entropy * target_mask_t,
                                              axis=0)
            self.recon_loss = tf.reduce_sum(
                cross_entropy * target_mask_t) / batch_size

        with tf.variable_scope("latent"):
            # Without a prior network this would be:
            # self.kl_loss = 0.5 * tf.reduce_sum(
            #     tf.exp(self.log_var) + self.mu ** 2 - 1. - self.log_var, 0)
            self.kl_losses = 0.5 * tf.reduce_sum(
                tf.exp(self.log_var - self.p_log_var) +
                (self.mu - self.p_mu) ** 2 / tf.exp(self.p_log_var) - 1. -
                self.log_var + self.p_log_var,
                axis=1)
            self.kl_loss = tf.reduce_mean(self.kl_losses)

        with tf.variable_scope("bow"):
            # self.bow_loss = self.kl_weight * 0
            mlp_b = layers_core.Dense(vocab_size, use_bias=False,
                                      name="MLP_b")
            # is it a mistake that we only model on latent variable?
            latent_logits = mlp_b(
                tf.concat(
                    [self.z_sample, ori_encoder_state_flat, emoji_vec],
                    axis=1))  # [batch_size, vocab_size]
            latent_logits = tf.expand_dims(
                latent_logits, 0)  # [1, batch_size, vocab_size]
            latent_logits = tf.tile(
                latent_logits,
                [max_time, 1, 1])  # [max_time, batch_size, vocab_size]
            cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=self.rep_output,
                logits=latent_logits)  # [len, batch_size]
            self.bow_losses = tf.reduce_sum(cross_entropy * target_mask_t,
                                            axis=0)
            self.bow_loss = tf.reduce_sum(
                cross_entropy * target_mask_t) / batch_size

        if is_seq2seq:
            self.kl_losses = self.kl_losses - self.kl_losses
            self.bow_losses = self.bow_losses - self.bow_losses
            self.kl_loss = self.kl_loss - self.kl_loss
            self.bow_loss = self.bow_loss - self.bow_loss

        self.losses = (self.recon_losses +
                       self.kl_losses * self.kl_weight * kl_ceiling +
                       self.bow_losses * bow_ceiling)
        self.loss = tf.reduce_mean(self.losses)

    # Calculate and clip gradients.
    with tf.variable_scope("optimization"):
        params = tf.trainable_variables()
        gradients = tf.gradients(self.loss, params)
        clipped_gradients, _ = tf.clip_by_global_norm(gradients,
                                                      max_gradient_norm)
        # Optimization
        optimizer = tf.train.AdamOptimizer(lr)
        self.update_step = optimizer.apply_gradients(
            zip(clipped_gradients, params))

    with tf.variable_scope("policy_loss"):
        # NOTE: this assumes beam_width == 0, since only BasicDecoder
        # outputs expose rnn_output; FinalBeamSearchDecoderOutput does not.
        prob = tf.nn.softmax(
            infer_outputs.rnn_output)  # [max_len, batch_size, vocab_size]
        prob = tf.clip_by_value(prob, 1e-15, 1000.)
        output_prob = tf.reduce_max(tf.log(prob),
                                    axis=2)  # [max_len, batch_size]
        seq_log_prob = tf.reduce_sum(output_prob, axis=0)  # [batch_size]
        # reward = tf.nn.relu(self.reward)
        self.policy_losses = -self.reward * seq_log_prob
        self.policy_losses *= (0.5 - 1) * self.ac5_vec + 1

    with tf.variable_scope("policy_optimization"):
        # zero = tf.constant(0, dtype=tf.float32)
        # where = tf.cast(tf.less(self.reward, zero), tf.float32)
        # recon = tf.reduce_sum(self.recon_losses * where) / tf.reduce_sum(where)
        final_loss = self.policy_losses * (
            1 - self.ac_vec) * self.policy_weight
        final_loss += self.losses * self.loss_weight
        self.policy_loss = tf.reduce_mean(final_loss)
        # final_loss = self.losses * self.loss_weight + self.policy_losses * self.policy_weight
        # final_loss *= (1 - self.ac_vec)
        # self.policy_loss = tf.reduce_sum(final_loss) / tf.reduce_sum((1 - self.ac_vec))
        gradients = tf.gradients(self.policy_loss, params)
        clipped_gradients, _ = tf.clip_by_global_norm(gradients,
                                                      max_gradient_norm)
        optimizer = tf.train.AdamOptimizer(lr)
        self.policy_step = optimizer.apply_gradients(
            zip(clipped_gradients, params))
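# The "latent" term above is the closed-form KL divergence between diagonal
# Gaussians, KL(q || p) with q = N(mu, exp(log_var)) and
# p = N(p_mu, exp(p_log_var)):
#
#   KL = 0.5 * sum_d [ exp(log_var - p_log_var)
#                      + (mu - p_mu)^2 / exp(p_log_var)
#                      - 1 - log_var + p_log_var ]
#
# A quick numpy check (not from the original code): the KL of a
# distribution against itself is zero.
import numpy as np

def kl_diag_gauss(mu, log_var, p_mu, p_log_var):
    return 0.5 * np.sum(np.exp(log_var - p_log_var)
                        + (mu - p_mu) ** 2 / np.exp(p_log_var)
                        - 1. - log_var + p_log_var, axis=-1)

assert np.isclose(kl_diag_gauss(np.zeros(4), np.zeros(4),
                                np.zeros(4), np.zeros(4)), 0.0)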
def _build_decoder(model, encoder_outputs, encoder_state, hparams,
                   start_token, end_token, output_layer, aux_hidden_state):
    """Build decoder for the seq2seq model."""
    iterator = model.iterator
    start_token_id = tf.cast(
        model.vocab_table.lookup(tf.constant(start_token)), tf.int32)
    end_token_id = tf.cast(
        model.vocab_table.lookup(tf.constant(end_token)), tf.int32)
    start_tokens = tf.fill([model.batch_size], start_token_id)
    end_token = end_token_id

    ## Decoder.
    with tf.variable_scope("decoder") as decoder_scope:
        cell, decoder_initial_state = _build_decoder_cell(
            model, hparams, encoder_state, base_gpu=model.global_gpu_num)
        model.global_gpu_num += hparams.num_layers

        ## Train or eval.
        decoder_emb_inp = tf.nn.embedding_lookup(model.embedding_decoder,
                                                 iterator.target)
        # Helper
        helper_train = help_py.TrainingHelper(
            decoder_emb_inp, iterator.dialogue_len, time_major=False)
        # Decoder
        my_decoder_train = basic_decoder.BasicDecoder(
            cell,
            helper_train,
            decoder_initial_state,
            encoder_outputs,
            iterator.turns,
            output_layer=output_layer,
            aux_hidden_state=aux_hidden_state)
        # Dynamic decoding
        outputs_train, _, _ = seq2seq.dynamic_decode(
            my_decoder_train,
            output_time_major=False,
            swap_memory=True,
            scope=decoder_scope)
        sample_id_train = outputs_train.sample_id
        logits_train = outputs_train.rnn_output

        ## Inference.
        beam_width = hparams.beam_width
        length_penalty_weight = hparams.length_penalty_weight
        if model.mode == tf.estimator.ModeKeys.PREDICT and beam_width > 0:
            my_decoder_infer = seq2seq.BeamSearchDecoder(
                cell=cell,
                embedding=model.embedding_decoder,
                start_tokens=start_tokens,
                end_token=end_token,
                initial_state=decoder_initial_state,
                beam_width=beam_width,
                output_layer=output_layer,
                length_penalty_weight=length_penalty_weight)
        else:
            # Helper
            if model.mode in dialogue_utils.self_play_modes:
                helper_infer = seq2seq.SampleEmbeddingHelper(
                    model.embedding_decoder, start_tokens, end_token)
            else:  # inference
                helper_infer = seq2seq.GreedyEmbeddingHelper(
                    model.embedding_decoder, start_tokens, end_token)
            # Decoder
            my_decoder_infer = seq2seq.BasicDecoder(
                cell,
                helper_infer,
                decoder_initial_state,
                output_layer=output_layer  # applied per timestep
            )
        # Dynamic decoding
        outputs_infer, _, _ = seq2seq.dynamic_decode(
            my_decoder_infer,
            maximum_iterations=hparams.max_inference_len,
            output_time_major=False,
            swap_memory=True,
            scope=decoder_scope)
        if model.mode == tf.estimator.ModeKeys.PREDICT and beam_width > 0:
            logits_infer = tf.no_op()
            sample_id_infer = outputs_infer.predicted_ids
        else:
            logits_infer = outputs_infer.rnn_output
            sample_id_infer = outputs_infer.sample_id
    return logits_train, logits_infer, sample_id_train, sample_id_infer
def build_decode(self):
    # Build decoder and attention.
    with tf.variable_scope('decoder'):
        self.decoder_cell, self.decoder_initial_state = self.build_decode_cell()
        # Input projection layer to feed embedded inputs to the cell.
        # ** Essential when use_residual=True to match input/output dims.
        input_layer = Dense(self.hidden_units,
                            dtype=self.dtype,
                            name='input_projection')
        # Output projection layer to convert cell outputs to logits.
        output_layer = Dense(self.num_decoder_symbols, name='output_project')
        if self.mode == 'train':
            # decoder_inputs_embedded: [batch_size, max_time_step + 1, embedding_size]
            self.decoder_inputs_embedded = tf.nn.embedding_lookup(
                self.embeddings, self.decoder_inputs_train)
            # Embedded inputs having gone through the input projection layer.
            self.decoder_inputs_embedded = input_layer(
                self.decoder_inputs_embedded)
            # Helper to feed inputs for training: read inputs from dense
            # ground-truth vectors.
            training_helper = seq2seq.TrainingHelper(
                inputs=self.decoder_inputs_embedded,
                sequence_length=self.decoder_inputs_length_train,
                time_major=False,
                name='training_helper')
            training_decoder = seq2seq.BasicDecoder(
                cell=self.decoder_cell,
                helper=training_helper,
                initial_state=self.decoder_initial_state,
                output_layer=output_layer)
            # Maximum decoder time steps in the current batch.
            max_decoder_length = tf.reduce_max(self.decoder_inputs_length_train)
            # decoder_outputs_train: BasicDecoderOutput
            #   namedtuple(rnn_output, sample_id)
            # decoder_outputs_train.rnn_output:
            #   [batch_size, max_time_step + 1, num_decoder_symbols] if output_time_major=False
            #   [max_time_step + 1, batch_size, num_decoder_symbols] if output_time_major=True
            # decoder_outputs_train.sample_id: [batch_size], tf.int32
            (self.decoder_outputs_train, self.decoder_last_state_train,
             self.decoder_outputs_length_train) = seq2seq.dynamic_decode(
                 decoder=training_decoder,
                 output_time_major=False,
                 impute_finished=True,
                 maximum_iterations=max_decoder_length)
            # More efficient to do the projection on the
            # batch-time-concatenated tensor:
            # logits_train: [batch_size, max_time_step + 1, num_decoder_symbols]
            # self.decoder_logits_train = output_layer(self.decoder_outputs_train.rnn_output)
            self.decoder_logits_train = tf.identity(
                self.decoder_outputs_train.rnn_output)
            # Use argmax to extract decoder symbols to emit.
            self.decoder_pred_train = tf.argmax(self.decoder_logits_train,
                                                axis=-1,
                                                name='decoder_pred_train')
            # masks: masking for valid and padded time steps,
            # [batch_size, max_time_step + 1]
            masks = tf.sequence_mask(lengths=self.decoder_inputs_length_train,
                                     maxlen=max_decoder_length,
                                     dtype=self.dtype,
                                     name='masks')
            self.loss = seq2seq.sequence_loss(
                logits=self.decoder_logits_train,
                targets=self.decoder_targets_train,
                weights=masks,
                average_across_timesteps=True,
                average_across_batch=True)
            # Training summary for the current batch loss.
            tf.summary.scalar('loss', self.loss)
        elif self.mode == 'decode':
            # start_tokens: [batch_size,] `int32` vector
            start_token = tf.ones([self.batch_size], tf.int32) * data_utils.GO_ID
            end_token = data_utils.EOS_ID

            def embed_and_input_proj(inputs):
                return input_layer(
                    tf.nn.embedding_lookup(self.embeddings, inputs))

            if not self.use_beamsearch_decode:
                decoding_helper = seq2seq.GreedyEmbeddingHelper(
                    start_tokens=start_token,
                    end_token=end_token,
                    embedding=embed_and_input_proj)
                inference_decoder = seq2seq.BasicDecoder(
                    cell=self.decoder_cell,
                    helper=decoding_helper,
                    initial_state=self.decoder_initial_state,
                    output_layer=output_layer)
            else:
                inference_decoder = seq2seq.BeamSearchDecoder(
                    cell=self.decoder_cell,
                    embedding=embed_and_input_proj,
                    start_tokens=start_token,
                    end_token=end_token,
                    initial_state=self.decoder_initial_state,
                    beam_width=self.beam_width,  # fixed typo: was self.beam_with
                    output_layer=output_layer)
            # For GreedyDecoder, returns
            #   decoder_outputs_decode: BasicDecoderOutput instance
            #     namedtuple(rnn_output, sample_id)
            #   decoder_outputs_decode.rnn_output:
            #     [batch_size, max_time_step, num_decoder_symbols] if output_time_major=False
            #     [max_time_step, batch_size, num_decoder_symbols] if output_time_major=True
            #   decoder_outputs_decode.sample_id:
            #     [batch_size, max_time_step], tf.int32 if output_time_major=False
            #     [max_time_step, batch_size], tf.int32 if output_time_major=True
            # For BeamSearchDecoder, returns
            #   decoder_outputs_decode: FinalBeamSearchDecoderOutput instance
            #     namedtuple(predicted_ids, beam_search_decoder_output)
            #   decoder_outputs_decode.predicted_ids:
            #     [batch_size, max_time_step, beam_width] if output_time_major=False
            #     [max_time_step, batch_size, beam_width] if output_time_major=True
            #   decoder_outputs_decode.beam_search_decoder_output:
            #     BeamSearchDecoderOutput instance
            #     namedtuple(scores, predicted_ids, parent_ids)
            (self.decoder_outputs_decode, self.decoder_last_state_decode,
             self.decoder_outputs_length_decode) = seq2seq.dynamic_decode(
                 decoder=inference_decoder,
                 output_time_major=False,
                 maximum_iterations=self.config.max_decode_step)
            if not self.use_beamsearch_decode:
                # decoder_outputs_decode.sample_id: [batch_size, max_time_step]
                # Or use argmax to find decoder symbols to emit:
                # self.decoder_pred_decode = tf.argmax(
                #     self.decoder_outputs_decode.rnn_output, axis=-1,
                #     name='decoder_pred_decode')
                # Here, expand_dims keeps the result compatible with the
                # beam-search decoder output.
                # decoder_pred_decode: [batch_size, max_time_step, 1] (output_major=False)
                self.decoder_pred_decode = tf.expand_dims(
                    self.decoder_outputs_decode.sample_id, -1)
            else:
                # Use beam search to approximately find the most likely
                # translation.
                # decoder_pred_decode: [batch_size, max_time_step, beam_width] (output_major=False)
                self.decoder_pred_decode = self.decoder_outputs_decode.predicted_ids
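
# Post-processing sketch for the two prediction shapes produced above: both
# greedy ([batch, time, 1]) and beam ([batch, time, beam_width]) outputs put
# the best hypothesis at beam index 0, so it can be extracted and trimmed at
# the first EOS. Pure numpy; `eos_id` and the function name are assumptions.
import numpy as np

def trim_best_beam(pred_ids, eos_id):
    best = pred_ids[:, :, 0]  # beams are sorted by score; index 0 is best
    trimmed = []
    for row in best:
        eos_positions = np.where(row == eos_id)[0]
        trimmed.append(row[:eos_positions[0]] if len(eos_positions) else row)
    return trimmed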
def _build_decoder(self, encoder_outputs, encoder_state, hparams):
    """Build and run an RNN decoder with a final projection layer.

    Args:
      encoder_outputs: The outputs of the encoder for every time step.
      encoder_state: The final state of the encoder.
      hparams: The hyperparameter configuration.

    Returns:
      A tuple of final logits and final decoder state:
        logits: size [time, batch_size, vocab_size] when time_major=True.
    """
    tgt_sos_id = tf.cast(
        self.tgt_vocab_table.lookup(tf.constant(hparams.sos)), tf.int32)
    tgt_eos_id = tf.cast(
        self.tgt_vocab_table.lookup(tf.constant(hparams.eos)), tf.int32)
    iterator = self.iterator

    # maximum_iterations: the maximum number of decoding steps.
    maximum_iterations = self._get_infer_maximum_iterations(
        hparams, iterator.source_sequence_length)

    ## Decoder.
    with tf.variable_scope("decoder") as decoder_scope:
        cell, decoder_initial_state = self._build_decoder_cell(
            hparams, encoder_outputs, encoder_state,
            iterator.source_sequence_length)

        ## Train or eval
        if self.mode != tf.contrib.learn.ModeKeys.INFER:
            # decoder_emb_inp: [max_time, batch_size, num_units]
            target_input = iterator.target_input
            if self.time_major:
                target_input = tf.transpose(target_input)
            decoder_emb_inp = tf.nn.embedding_lookup(self.embedding_decoder,
                                                     target_input)
            # Helper
            helper = seq2seq.TrainingHelper(
                decoder_emb_inp,
                iterator.target_sequence_length,
                time_major=self.time_major)
            # Decoder
            my_decoder = seq2seq.BasicDecoder(cell, helper,
                                              decoder_initial_state)
            # Dynamic decoding
            outputs, final_context_state, _ = seq2seq.dynamic_decode(
                my_decoder,
                output_time_major=self.time_major,
                swap_memory=True,
                scope=decoder_scope)
            sample_id = outputs.sample_id
            # Note: there's a subtle difference here between train and
            # inference. We could have set output_layer when creating
            # my_decoder and shared more code between train and inference.
            # We chose to apply the output_layer to all timesteps for speed:
            # 10% improvement for small models & 20% for larger ones.
            # If memory is a concern, apply output_layer per timestep instead.
            logits = self.output_layer(outputs.rnn_output)

        ## Inference
        else:
            beam_width = hparams.beam_width
            length_penalty_weight = hparams.length_penalty_weight
            start_tokens = tf.fill([self.batch_size], tgt_sos_id)
            end_token = tgt_eos_id
            if beam_width > 0:
                my_decoder = seq2seq.BeamSearchDecoder(
                    cell=cell,
                    embedding=self.embedding_decoder,
                    start_tokens=start_tokens,
                    end_token=end_token,
                    initial_state=decoder_initial_state,
                    beam_width=beam_width,
                    output_layer=self.output_layer,
                    length_penalty_weight=length_penalty_weight)
            else:
                # Helper
                sampling_temperature = hparams.sampling_temperature
                if sampling_temperature > 0.0:
                    helper = seq2seq.SampleEmbeddingHelper(
                        self.embedding_decoder,
                        start_tokens,
                        end_token,
                        softmax_temperature=sampling_temperature,
                        seed=hparams.random_seed)
                else:
                    helper = seq2seq.GreedyEmbeddingHelper(
                        self.embedding_decoder, start_tokens, end_token)
                # Decoder
                my_decoder = seq2seq.BasicDecoder(
                    cell,
                    helper,
                    decoder_initial_state,
                    output_layer=self.output_layer  # applied per timestep
                )
            # Dynamic decoding
            outputs, final_context_state, _ = seq2seq.dynamic_decode(
                my_decoder,
                maximum_iterations=maximum_iterations,
                output_time_major=self.time_major,
                swap_memory=True,
                scope=decoder_scope)
            if beam_width > 0:
                # TODO: rerank here.
                logits = tf.no_op()
                sample_id = outputs.predicted_ids
            else:
                logits = outputs.rnn_output
                sample_id = outputs.sample_id
    return logits, sample_id, final_context_state
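
# Self-contained miniature of the greedy inference path above, with toy
# sizes; everything here (names, dimensions, token ids) is illustrative
# rather than taken from the source model.
import tensorflow as tf
from tensorflow.contrib import seq2seq

toy_vocab, toy_emb, toy_units, toy_batch = 10, 8, 16, 4
toy_embedding = tf.get_variable("toy_embedding", [toy_vocab, toy_emb])
toy_cell = tf.nn.rnn_cell.BasicLSTMCell(toy_units)
toy_helper = seq2seq.GreedyEmbeddingHelper(
    toy_embedding, start_tokens=tf.fill([toy_batch], 1), end_token=2)
toy_decoder = seq2seq.BasicDecoder(
    toy_cell, toy_helper, toy_cell.zero_state(toy_batch, tf.float32),
    output_layer=tf.layers.Dense(toy_vocab))
toy_outputs, _, _ = seq2seq.dynamic_decode(toy_decoder, maximum_iterations=20)
# toy_outputs.rnn_output: [4, <=20, 10]; toy_outputs.sample_id: [4, <=20]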
def _build_forward(self):
    config = self.config
    N, M, JX, JQ, VW, VC, d, W = \
        config.batch_size, config.max_num_sents, config.max_sent_size, \
        config.max_ques_size, config.word_vocab_size, config.char_vocab_size, \
        config.hidden_size, config.max_word_size
    beam_width = config.beam_width
    GO_TOKEN = 0
    EOS_TOKEN = 1
    JX = tf.shape(self.x)[2]
    JQ = tf.shape(self.q)[1]
    M = tf.shape(self.x)[1]
    dc, dw, dco = config.char_emb_size, config.word_emb_size, config.char_out_size

    with tf.variable_scope("emb"):
        if config.use_char_emb:
            with tf.variable_scope("emb_var"), tf.device("/cpu:0"):
                char_emb_mat = tf.get_variable("char_emb_mat",
                                               shape=[VC, dc],
                                               dtype='float')
            with tf.variable_scope("char"):
                Acx = tf.nn.embedding_lookup(char_emb_mat, self.cx)  # [N, M, JX, W, dc]
                Acq = tf.nn.embedding_lookup(char_emb_mat, self.cq)  # [N, JQ, W, dc]
                Acx = tf.reshape(Acx, [-1, JX, W, dc])
                Acq = tf.reshape(Acq, [-1, JQ, W, dc])
                filter_sizes = list(map(int, config.out_channel_dims.split(',')))
                heights = list(map(int, config.filter_heights.split(',')))
                assert sum(filter_sizes) == dco, (filter_sizes, dco)
                with tf.variable_scope("conv"):
                    xx = multi_conv1d(Acx, filter_sizes, heights, "VALID",
                                      self.is_train, config.keep_prob,
                                      scope="xx")
                    if config.share_cnn_weights:
                        tf.get_variable_scope().reuse_variables()
                        qq = multi_conv1d(Acq, filter_sizes, heights, "VALID",
                                          self.is_train, config.keep_prob,
                                          scope="xx")
                    else:
                        qq = multi_conv1d(Acq, filter_sizes, heights, "VALID",
                                          self.is_train, config.keep_prob,
                                          scope="qq")
                    xx = tf.reshape(xx, [-1, M, JX, dco])
                    qq = tf.reshape(qq, [-1, JQ, dco])
        if config.use_word_emb:
            with tf.variable_scope("emb_var"), tf.device("/cpu:0"):
                if config.mode == 'train':
                    word_emb_mat = tf.get_variable(
                        "word_emb_mat",
                        dtype='float',
                        shape=[VW, dw],
                        initializer=get_initializer(config.emb_mat),
                        trainable=True)
                else:
                    word_emb_mat = tf.get_variable("word_emb_mat",
                                                   shape=[VW, dw],
                                                   dtype='float')
                if config.use_glove_for_unk:
                    word_emb_mat = tf.concat(
                        axis=0, values=[word_emb_mat, self.new_emb_mat])
            with tf.name_scope("word"):
                Ax = tf.nn.embedding_lookup(word_emb_mat, self.x)  # [N, M, JX, d]
                Aq = tf.nn.embedding_lookup(word_emb_mat, self.q)  # [N, JQ, d]
                self.tensor_dict['x'] = Ax
                self.tensor_dict['q'] = Aq
            if config.use_char_emb:
                xx = tf.concat(axis=3, values=[xx, Ax])  # [N, M, JX, di]
                qq = tf.concat(axis=2, values=[qq, Aq])  # [N, JQ, di]
            else:
                xx = Ax
                qq = Aq

    # highway network
    if config.highway:
        with tf.variable_scope("highway"):
            xx = highway_network(xx, config.highway_num_layers, True,
                                 wd=config.wd, is_train=self.is_train)
            tf.get_variable_scope().reuse_variables()
            qq = highway_network(qq, config.highway_num_layers, True,
                                 wd=config.wd, is_train=self.is_train)
    self.tensor_dict['xx'] = xx
    self.tensor_dict['qq'] = qq

    cell_fw = BasicLSTMCell(d, state_is_tuple=True)
    cell_bw = BasicLSTMCell(d, state_is_tuple=True)
    d_cell_fw = SwitchableDropoutWrapper(
        cell_fw, self.is_train, input_keep_prob=config.input_keep_prob)
    d_cell_bw = SwitchableDropoutWrapper(
        cell_bw, self.is_train, input_keep_prob=config.input_keep_prob)
    cell2_fw = BasicLSTMCell(d, state_is_tuple=True)
    cell2_bw = BasicLSTMCell(d, state_is_tuple=True)
    d_cell2_fw = SwitchableDropoutWrapper(
        cell2_fw, self.is_train, input_keep_prob=config.input_keep_prob)
    d_cell2_bw = SwitchableDropoutWrapper(
        cell2_bw, self.is_train, input_keep_prob=config.input_keep_prob)
    cell3_fw = BasicLSTMCell(d, state_is_tuple=True)
    cell3_bw = BasicLSTMCell(d, state_is_tuple=True)
    d_cell3_fw = SwitchableDropoutWrapper(
        cell3_fw, self.is_train, input_keep_prob=config.input_keep_prob)
    d_cell3_bw = SwitchableDropoutWrapper(
        cell3_bw, self.is_train, input_keep_prob=config.input_keep_prob)
    cell4_fw = BasicLSTMCell(d, state_is_tuple=True)
    cell4_bw = BasicLSTMCell(d, state_is_tuple=True)
    d_cell4_fw = SwitchableDropoutWrapper(
        cell4_fw, self.is_train, input_keep_prob=config.input_keep_prob)
    d_cell4_bw = SwitchableDropoutWrapper(
        cell4_bw, self.is_train, input_keep_prob=config.input_keep_prob)
    x_len = tf.reduce_sum(tf.cast(self.x_mask, 'int32'), 2)  # [N, M]
    q_len = tf.reduce_sum(tf.cast(self.q_mask, 'int32'), 1)  # [N]

    with tf.variable_scope("prepro"):
        (fw_u, bw_u), ((_, fw_u_f), (_, bw_u_f)) = bidirectional_dynamic_rnn(
            d_cell_fw, d_cell_bw, qq, q_len, dtype='float',
            scope='u1')  # [N, J, d], [N, d]
        u = tf.concat(axis=2, values=[fw_u, bw_u])
        if config.share_lstm_weights:
            tf.get_variable_scope().reuse_variables()
            (fw_h, bw_h), ((_, fw_h_f), (_, bw_h_f)) = bidirectional_dynamic_rnn(
                cell_fw, cell_bw, xx, x_len, dtype='float',
                scope='u1')  # [N, M, JX, 2d]
            h = tf.concat(axis=3, values=[fw_h, bw_h])  # [N, M, JX, 2d]
        else:
            (fw_h, bw_h), ((_, fw_h_f), (_, bw_h_f)) = bidirectional_dynamic_rnn(
                cell_fw, cell_bw, xx, x_len, dtype='float',
                scope='h1')  # [N, M, JX, 2d]
            h = tf.concat(axis=3, values=[fw_h, bw_h])  # [N, M, JX, 2d]
        self.tensor_dict['u'] = u
        self.tensor_dict['h'] = h

    with tf.variable_scope("main"):
        if config.dynamic_att:
            p0 = h
            u = tf.reshape(tf.tile(tf.expand_dims(u, 1), [1, M, 1, 1]),
                           [N * M, JQ, 2 * d])
            q_mask = tf.reshape(
                tf.tile(tf.expand_dims(self.q_mask, 1), [1, M, 1]),
                [N * M, JQ])
            first_cell_fw = AttentionCell(
                cell2_fw, u, mask=q_mask, mapper='sim',
                input_keep_prob=self.config.input_keep_prob,
                is_train=self.is_train)
            first_cell_bw = AttentionCell(
                cell2_bw, u, mask=q_mask, mapper='sim',
                input_keep_prob=self.config.input_keep_prob,
                is_train=self.is_train)
            second_cell_fw = AttentionCell(
                cell3_fw, u, mask=q_mask, mapper='sim',
                input_keep_prob=self.config.input_keep_prob,
                is_train=self.is_train)
            second_cell_bw = AttentionCell(
                cell3_bw, u, mask=q_mask, mapper='sim',
                input_keep_prob=self.config.input_keep_prob,
                is_train=self.is_train)
        else:
            p0 = attention_layer(config, self.is_train, h, u,
                                 h_mask=self.x_mask, u_mask=self.q_mask,
                                 scope="p0", tensor_dict=self.tensor_dict)
            first_cell_fw = d_cell2_fw
            second_cell_fw = d_cell3_fw
            first_cell_bw = d_cell2_bw
            second_cell_bw = d_cell3_bw
        (fw_g0, bw_g0), _ = bidirectional_dynamic_rnn(
            first_cell_fw, first_cell_bw, p0, x_len, dtype='float',
            scope='g0')  # [N, M, JX, 2d]
        g0 = tf.concat(axis=3, values=[fw_g0, bw_g0])
        (fw_g1, bw_g1), _ = bidirectional_dynamic_rnn(
            second_cell_fw, second_cell_bw, g0, x_len, dtype='float',
            scope='g1')  # [N, M, JX, 2d]
        g1 = tf.concat(axis=3, values=[fw_g1, bw_g1])
        logits = get_logits([g1, p0], d, True, wd=config.wd,
                            input_keep_prob=config.input_keep_prob,
                            mask=self.x_mask, is_train=self.is_train,
                            func=config.answer_func, scope='logits1')
        a1i = softsel(tf.reshape(g1, [N, M * JX, 2 * d]),
                      tf.reshape(logits, [N, M * JX]))
        a1i = tf.tile(tf.expand_dims(tf.expand_dims(a1i, 1), 1), [1, M, JX, 1])
        (fw_g2, bw_g2), _ = bidirectional_dynamic_rnn(
            d_cell4_fw, d_cell4_bw,
            tf.concat(axis=3, values=[p0, g1, a1i, g1 * a1i]),
            x_len, dtype='float', scope='g2')  # [N, M, JX, 2d]
        g2 = tf.concat(axis=3, values=[fw_g2, bw_g2])
        logits2 = get_logits([g2, p0], d, True, wd=config.wd,
                             input_keep_prob=config.input_keep_prob,
                             mask=self.x_mask, is_train=self.is_train,
                             func=config.answer_func, scope='logits2')
        flat_logits = tf.reshape(logits, [-1, M * JX])
        flat_yp = tf.nn.softmax(flat_logits)  # [-1, M*JX]
        flat_logits2 = tf.reshape(logits2, [-1, M * JX])
        flat_yp2 = tf.nn.softmax(flat_logits2)
        if config.na:
            na_bias = tf.get_variable("na_bias", shape=[], dtype='float')
            na_bias_tiled = tf.tile(tf.reshape(na_bias, [1, 1]), [N, 1])  # [N, 1]
            concat_flat_logits = tf.concat(
                axis=1, values=[na_bias_tiled, flat_logits])
            concat_flat_yp = tf.nn.softmax(concat_flat_logits)
            na_prob = tf.squeeze(tf.slice(concat_flat_yp, [0, 0], [-1, 1]), [1])
            flat_yp = tf.slice(concat_flat_yp, [0, 1], [-1, -1])
            concat_flat_logits2 = tf.concat(
                axis=1, values=[na_bias_tiled, flat_logits2])
            concat_flat_yp2 = tf.nn.softmax(concat_flat_logits2)
            na_prob2 = tf.squeeze(
                tf.slice(concat_flat_yp2, [0, 0], [-1, 1]), [1])  # [N]
            flat_yp2 = tf.slice(concat_flat_yp2, [0, 1], [-1, -1])
            self.concat_logits = concat_flat_logits
            self.concat_logits2 = concat_flat_logits2
            self.na_prob = na_prob * na_prob2
        yp = tf.reshape(flat_yp, [-1, M, JX])
        yp2 = tf.reshape(flat_yp2, [-1, M, JX])
        wyp = tf.nn.sigmoid(logits2)
        self.tensor_dict['g1'] = g1
        self.tensor_dict['g2'] = g2
        self.logits = flat_logits
        self.logits2 = flat_logits2
        self.yp = yp
        self.yp2 = yp2
        self.wyp = wyp

    with tf.variable_scope("q_gen"):
        # Question generation using (paragraph & predicted answer position).
        NM = config.max_num_sents * config.batch_size
        # Separated encoder
        # ss = tf.reshape(xx, (-1, JX, dw + dco))
        q_worthy = tf.reduce_sum(
            tf.to_int32(self.y),
            axis=2)  # distribution of how answer-likely each sentence is, (N, M)
        q_worthy = tf.expand_dims(tf.to_int32(tf.argmax(q_worthy, axis=1)),
                                  axis=1)  # (N) -> (N, 1)
        q_worthy = tf.concat(
            [tf.expand_dims(tf.range(0, N, dtype=tf.int32), axis=1), q_worthy],
            axis=1)
        # example: [0, 9], [1, 11], [2, 8], [3, 5], [4, 0], [5, 1] ...
        ss = tf.gather_nd(xx, q_worthy)
        syp = tf.expand_dims(tf.gather_nd(yp, q_worthy), axis=-1)
        syp2 = tf.expand_dims(tf.gather_nd(yp2, q_worthy), axis=-1)
        ss_with_ans = tf.concat([ss, syp, syp2], axis=2)
        qg_dim = 600
        cell_fw = rnn.DropoutWrapper(rnn.GRUCell(qg_dim),
                                     input_keep_prob=config.input_keep_prob)
        cell_bw = rnn.DropoutWrapper(rnn.GRUCell(qg_dim),
                                     input_keep_prob=config.input_keep_prob)
        s_outputs, s_states = tf.nn.bidirectional_dynamic_rnn(
            cell_fw, cell_bw, ss_with_ans, dtype=tf.float32)
        s_outputs = tf.concat(s_outputs, axis=2)
        s_states = tf.concat(s_states, axis=1)
        start_tokens = tf.zeros([N], dtype=tf.int32)
        self.inp_q_with_GO = tf.concat(
            [tf.expand_dims(start_tokens, axis=1), self.q], axis=1)
        # Supervise if mode is train.
        if config.mode == "train":
            emb_q = tf.nn.embedding_lookup(params=word_emb_mat,
                                           ids=self.inp_q_with_GO)
            # emb_q = tf.reshape(tf.tile(tf.expand_dims(emb_q, axis=1), [1, M, 1, 1]), (NM, JQ + 1, dw))
            train_helper = seq2seq.TrainingHelper(emb_q, [JQ] * N)
        else:
            s_outputs = seq2seq.tile_batch(s_outputs, multiplier=beam_width)
            s_states = seq2seq.tile_batch(s_states, multiplier=beam_width)
        cell = rnn.DropoutWrapper(rnn.GRUCell(num_units=qg_dim * 2),
                                  input_keep_prob=config.input_keep_prob)
        attention_mechanism = seq2seq.BahdanauAttention(num_units=qg_dim * 2,
                                                        memory=s_outputs)
        attn_cell = seq2seq.AttentionWrapper(cell,
                                             attention_mechanism,
                                             attention_layer_size=qg_dim * 2,
                                             output_attention=True,
                                             alignment_history=False)
        total_glove_vocab_size = 78878  # 72686
        out_cell = rnn.OutputProjectionWrapper(attn_cell,
                                               VW + total_glove_vocab_size)
        if config.mode == "train":
            decoder_initial_states = out_cell.zero_state(
                batch_size=N, dtype=tf.float32).clone(cell_state=s_states)
            decoder = seq2seq.BasicDecoder(
                cell=out_cell,
                helper=train_helper,
                initial_state=decoder_initial_states)
        else:
            decoder_initial_states = out_cell.zero_state(
                batch_size=N * beam_width,
                dtype=tf.float32).clone(cell_state=s_states)
            decoder = seq2seq.BeamSearchDecoder(
                cell=out_cell,
                embedding=word_emb_mat,
                start_tokens=start_tokens,
                end_token=EOS_TOKEN,
                initial_state=decoder_initial_states,
                beam_width=beam_width,
                length_penalty_weight=0.0)
        outputs = seq2seq.dynamic_decode(decoder=decoder,
                                         maximum_iterations=JQ)
        if config.mode == "train":
            gen_q = outputs[0].sample_id
            gen_q_prob = outputs[0].rnn_output
            gen_q_states = outputs[1]
        else:
            gen_q = outputs[0].predicted_ids[:, :, 0]
            gen_q_prob = tf.nn.embedding_lookup(
                params=word_emb_mat, ids=outputs[0].predicted_ids[:, :, 0])
            gen_q_states = outputs[1]
        self.gen_q = gen_q
        self.gen_q_prob = gen_q_prob
        self.gen_q_states = gen_q_states
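
# The gather_nd indexing used in the q_gen block above, in isolation (numpy
# stand-in with invented values): row i of the index picks xx[i, q_worthy_i],
# i.e. the most answer-worthy sentence for each batch entry.
import numpy as np

xx_demo = np.arange(24).reshape(2, 3, 4)        # [N=2, M=3, feature=4]
idx_demo = np.array([[0, 2], [1, 0]])           # (batch, sentence) index pairs
print(xx_demo[idx_demo[:, 0], idx_demo[:, 1]])  # what tf.gather_nd returns, [2, 4]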
def rbmE_gruD(mode, features, labels, params):
    inp = features["x"]
    # `state` was referenced but never defined in the original; deriving it
    # from the Estimator mode is an assumption that matches how it is used
    # below ("Training" / "Evaluating" / "Infering").
    state = {tf.estimator.ModeKeys.TRAIN: "Training",
             tf.estimator.ModeKeys.EVAL: "Evaluating"}.get(mode, "Infering")
    if state != "Infering":
        ids = features["ids"]
        weights = features["weights"]
    batch_size = params["batch_size"]

    # Encoder
    enc_cell = rnn.NASCell(num_units=NUM_UNITS)
    enc_out, enc_state = tf.nn.dynamic_rnn(enc_cell, inp,
                                           time_major=False,
                                           dtype=tf.float32)

    # Decoder
    cell = rnn.NASCell(num_units=NUM_UNITS)
    _, embeddings = load_processed_embeddings(sess=tf.InteractiveSession())
    out_lengths = tf.constant(seq_len, shape=[batch_size])
    if state != "Infering":
        # Sampling method for training.
        train_helper = seq2seq.TrainingHelper(labels, out_lengths,
                                              time_major=False)
        # train_helper = seq2seq.ScheduledEmbeddingTrainingHelper(
        #     inputs=labels, sequence_length=out_lengths,
        #     embedding=embeddings, sampling_probability=probs)
    # Sampling method for evaluation.
    start_tokens = tf.zeros([batch_size], dtype=tf.int32)
    infer_helper = seq2seq.GreedyEmbeddingHelper(embedding=embeddings,
                                                 start_tokens=start_tokens,
                                                 end_token=END)
    # infer_helper = seq2seq.SampleEmbeddingHelper(embeddings, start_tokens=start_tokens, end_token=END)
    # infer_helper = seq2seq.ScheduledEmbeddingTrainingHelper(inputs=inp, sequence_length=out_lengths, embedding=embeddings, sampling_probability=1.0)
    projection_layer = layers_core.Dense(vocab_size, use_bias=False)

    def decode(helper):
        decoder = seq2seq.BasicDecoder(cell=cell, helper=helper,
                                       initial_state=enc_state,
                                       output_layer=projection_layer)
        # decoder.tracks_own_finished = True
        dec_outputs, _, _ = seq2seq.dynamic_decode(decoder,
                                                   maximum_iterations=seq_len)
        dec_ids = dec_outputs.sample_id
        logits = dec_outputs.rnn_output
        return dec_ids, logits

    # Equalize logits, label and weight lengths in case the decoder finishes
    # early.
    def norm_logits_loss(logts, ids, weights):
        current_ts = tf.to_int32(
            tf.minimum(tf.shape(ids)[1], tf.shape(logts)[1]))
        logts = tf.slice(logts, begin=[0, 0, 0], size=[-1, current_ts, -1])
        ids = tf.slice(ids, begin=[0, 0], size=[-1, current_ts])
        weights = tf.slice(weights, begin=[0, 0], size=[-1, current_ts])
        return logts, ids, weights

    # Training mode
    if state == "Training":
        dec_ids, logits = decode(train_helper)
        # Some sample_id entries are overwritten with -1s:
        # dec_ids = tf.argmax(logits, axis=2)
        tf.identity(dec_ids, name="predictions")
        logits, ids, weights = norm_logits_loss(logits, ids, weights)
        loss = seq2seq.sequence_loss(logits, ids, weights=weights)
        learning_rate = 0.001  # 0.0001
        tf.identity(learning_rate, name="learning_rate")

    # Evaluation mode
    if state == "Evaluating" or state == "Testing":
        eval_dec_ids, eval_logits = decode(infer_helper)
        # eval_dec_ids = tf.argmax(eval_logits, axis=2)
        tf.identity(eval_dec_ids, name="predictions")
        eval_logits, ids, weights = norm_logits_loss(eval_logits, ids, weights)
        eval_loss = seq2seq.sequence_loss(eval_logits, ids, weights=weights)

    # Beam-search decoder
    init_state = seq2seq.tile_batch(enc_state, multiplier=5)
    beamSearch_decoder = seq2seq.BeamSearchDecoder(
        cell,
        embeddings,
        start_tokens,
        end_token=END,
        initial_state=init_state,
        beam_width=5,
        output_layer=projection_layer)
    infer_outputs, _, _ = seq2seq.dynamic_decode(beamSearch_decoder,
                                                 maximum_iterations=seq_len)
    infer_ids = infer_outputs.predicted_ids
    # Note: `scores` are cumulative log-probabilities, so taking a product
    # over time is questionable; kept as in the original.
    infer_probs = infer_outputs.beam_search_decoder_output.scores
    infer_probs = tf.reduce_prod(infer_probs, axis=1)
    infer_pos = tf.argmax(infer_probs, axis=1)
    infers = {"ids": infer_ids, "pos": infer_pos}

    if mode == tf.estimator.ModeKeys.TRAIN:
        train_op = layers.optimize_loss(loss,
                                        tf.train.get_global_step(),
                                        optimizer='Adam',
                                        learning_rate=learning_rate,
                                        clip_gradients=5.0)
        spec = tf.estimator.EstimatorSpec(mode=mode, predictions=dec_ids,
                                          loss=loss, train_op=train_op)
    # Evaluation mode
    elif mode == tf.estimator.ModeKeys.EVAL:
        spec = tf.estimator.EstimatorSpec(mode=mode, loss=eval_loss,
                                          predictions=eval_dec_ids)
    else:
        spec = tf.estimator.EstimatorSpec(mode=mode, predictions=infers)
    return spec
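
# tile_batch, as used for the beam-search initial state above, in isolation:
# each batch entry is repeated `multiplier` times consecutively along axis 0,
# which is the layout BeamSearchDecoder expects. Toy values, TF 1.x contrib.
import tensorflow as tf
from tensorflow.contrib import seq2seq

x_demo = tf.constant([[1, 2], [3, 4]])
tiled_demo = seq2seq.tile_batch(x_demo, multiplier=3)
# tiled_demo has shape [6, 2]: [1,2],[1,2],[1,2],[3,4],[3,4],[3,4]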
def __init__(self, vocab_size, hidden_size, dropout, num_layers,
             max_gradient_norm, batch_size, learning_rate, lr_decay_factor,
             max_target_length, max_source_length, decoder_mode=False):
    '''
    vocab_size: number of vocab tokens
    hidden_size: dimension of hidden layers
    dropout: keep probability for the dropout wrappers
    num_layers: number of hidden layers
    max_gradient_norm: maximum gradient magnitude
    batch_size: number of training examples fed to the network at once
    learning_rate: starting learning rate of the network
    lr_decay_factor: amount by which to decay the learning rate
    max_target_length / max_source_length: maximum sequence lengths
    decoder_mode: whether to build the inference (beam search) graph
                  instead of the training graph
    '''
    GO_ID = config.GO_ID
    EOS_ID = config.EOS_ID
    self.max_source_length = max_source_length
    self.max_target_length = max_target_length
    self.vocab_size = vocab_size
    self.batch_size = batch_size
    self.global_step = tf.Variable(0, trainable=False)
    self.learning_rate = learning_rate
    self.encoder_inputs = tf.placeholder(shape=(None, None), dtype=tf.int32,
                                         name='encoder_inputs')
    self.source_lengths = tf.placeholder(shape=(None,), dtype=tf.int32,
                                         name='source_lengths')
    self.decoder_targets = tf.placeholder(shape=(None, None), dtype=tf.int32,
                                          name='decoder_targets')
    self.target_lengths = tf.placeholder(shape=(None,), dtype=tf.int32,
                                         name='target_lengths')

    with tf.variable_scope('embeddings'):
        embeddings = tf.Variable(
            tf.random_uniform([vocab_size, hidden_size], -1.0, 1.0),
            dtype=tf.float32)
        encoder_inputs_embedded = tf.nn.embedding_lookup(
            embeddings, self.encoder_inputs)
        targets_embedding = tf.nn.embedding_lookup(embeddings,
                                                   self.decoder_targets)

    # One fresh cell per layer and per direction; the original reused a
    # single cell instance, which silently shares weights (and raises an
    # error in later TF versions).
    def make_cell():
        cell = rnn.LSTMCell(hidden_size)
        return rnn.DropoutWrapper(cell, input_keep_prob=dropout)

    with tf.variable_scope('encoder'):
        encoder_cell_fw = tf.nn.rnn_cell.MultiRNNCell(
            [make_cell() for _ in range(num_layers)], state_is_tuple=True)
        encoder_cell_bw = tf.nn.rnn_cell.MultiRNNCell(
            [make_cell() for _ in range(num_layers)], state_is_tuple=True)
        # BiLSTM encoder
        encoder_outputs, encoder_state = tf.nn.bidirectional_dynamic_rnn(
            cell_fw=encoder_cell_fw,
            cell_bw=encoder_cell_bw,
            sequence_length=self.source_lengths,
            inputs=encoder_inputs_embedded,
            dtype=tf.float32,
            time_major=False)
        encoder_output = encoder_outputs[0]
        encoder_outputs = tf.concat(encoder_outputs, 2)

    with tf.variable_scope('decoder'):
        decoder_cell = tf.nn.rnn_cell.MultiRNNCell(
            [make_cell() for _ in range(num_layers)], state_is_tuple=True)
        attn_mech = seq2seq.BahdanauAttention(
            num_units=hidden_size,            # depth of the query mechanism
            memory=encoder_output,            # forward RNN hidden states
            memory_sequence_length=self.source_lengths,
            name='BahdanauAttention')
        # Attention layer
        attn_cell = seq2seq.AttentionWrapper(
            cell=decoder_cell,
            attention_mechanism=attn_mech,
            attention_layer_size=hidden_size,  # depth of the attention tensor
            name='attention_wrapper')
        if decoder_mode:
            beam_width = 1
            attn_zero = attn_cell.zero_state(
                batch_size=batch_size * beam_width, dtype=tf.float32)
            init_state = attn_zero.clone(cell_state=encoder_state)
            # Beam search in the decoder; the single start token suggests the
            # original decodes one example at a time here.
            decoder = seq2seq.BeamSearchDecoder(
                cell=attn_cell,
                embedding=embeddings,
                start_tokens=tf.tile([GO_ID], [1]),
                end_token=EOS_ID,
                initial_state=init_state,
                beam_width=beam_width,
                output_layer=Dense(vocab_size))
            final_outputs, final_state, final_sequence_lengths = \
                seq2seq.dynamic_decode(decoder=decoder)
            self.logits = final_outputs.predicted_ids
        else:
            helper = seq2seq.TrainingHelper(
                inputs=targets_embedding,
                sequence_length=self.target_lengths)
            decoder = seq2seq.BasicDecoder(
                cell=attn_cell,
                helper=helper,
                # encoder_state[0] is the forward-direction state of the
                # bidirectional encoder.
                initial_state=attn_cell.zero_state(
                    batch_size, tf.float32).clone(cell_state=encoder_state[0]),
                output_layer=Dense(vocab_size))
            final_outputs, final_state, final_sequence_lengths = \
                seq2seq.dynamic_decode(decoder=decoder)
            self.logits = final_outputs.rnn_output

    if not decoder_mode:
        with tf.variable_scope("loss"):
            # dynamic_decode can stop early, so the logits have to be padded
            # to the same time dimension as the targets.
            pad_size = self.max_target_length - tf.reduce_max(
                final_sequence_lengths)
            self.logits = tf.pad(self.logits, [[0, 0], [0, pad_size], [0, 0]])
            weights = tf.sequence_mask(lengths=final_sequence_lengths,
                                       maxlen=self.max_target_length,
                                       dtype=tf.float32,
                                       name='weights')
            # Cross-entropy loss
            x_entropy_loss = seq2seq.sequence_loss(
                logits=self.logits,
                targets=self.decoder_targets,
                weights=weights)
            self.loss = tf.reduce_mean(x_entropy_loss)
            # Adam optimization with per-value gradient clipping.
            optimizer = tf.train.AdamOptimizer()
            gradients = optimizer.compute_gradients(x_entropy_loss)
            capped_grads = [(tf.clip_by_value(grad, -max_gradient_norm,
                                              max_gradient_norm), var)
                            for grad, var in gradients]
            self.train_op = optimizer.apply_gradients(
                capped_grads, global_step=self.global_step)
    self.saver = tf.train.Saver(tf.global_variables())
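
# tf.sequence_mask, as used for the loss weights above, in isolation:
import tensorflow as tf

mask_demo = tf.sequence_mask([1, 3, 2], maxlen=4, dtype=tf.float32)
# [[1. 0. 0. 0.]
#  [1. 1. 1. 0.]
#  [1. 1. 0. 0.]]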
def build_model(self):
    '''Build the seq2seq model.'''
    self.query_input = tf.placeholder(tf.int32, [None, None])
    self.query_length = tf.placeholder(tf.int32, [None])
    self.answer_input = tf.placeholder(tf.int32, [None, None])
    self.answer_target = tf.placeholder(tf.int32, [None, None])
    self.answer_length = tf.placeholder(tf.int32, [None])
    self.batch_size = array_ops.shape(self.query_input)[0]
    if self.mode == "train":
        self.max_decode_step = tf.reduce_max(self.answer_length)
        self.sequence_mask = tf.sequence_mask(self.answer_length,
                                              self.max_decode_step,
                                              dtype=tf.float32)
    elif self.mode == "decode":
        self.max_decode_step = tf.reduce_max(self.query_length) * 10

    # Input and output embedding.
    self.embeddings_matrix = tf.Variable(
        tf.random_uniform([self.vocab_size, EMBEDDING_SIZE], -1.0, 1.0),
        dtype=tf.float32)
    self.query_embeddings = tf.nn.embedding_lookup(self.embeddings_matrix,
                                                   self.query_input)
    self.answer_embeddings = tf.nn.embedding_lookup(self.embeddings_matrix,
                                                    self.answer_input)

    # Encoder.
    self.encoder_outputs, self.encoder_state = tf.nn.dynamic_rnn(
        rnn.BasicLSTMCell(ENCODER_HIDDEN_SIZE),
        self.query_embeddings,
        sequence_length=self.query_length,
        dtype=tf.float32)

    # For beam search, prepare tiled copies of the encoder tensors for
    # reuse below.
    batch_size, encoder_outputs, encoder_state, encoder_length = (
        self.batch_size, self.encoder_outputs, self.encoder_state,
        self.query_length)
    if self.mode == "decode":
        batch_size = batch_size * BEAM_WIDTH
        encoder_outputs = seq2seq.tile_batch(t=self.encoder_outputs,
                                             multiplier=BEAM_WIDTH)
        encoder_state = nest.map_structure(
            lambda s: seq2seq.tile_batch(t=s, multiplier=BEAM_WIDTH),
            self.encoder_state)
        encoder_length = seq2seq.tile_batch(t=self.query_length,
                                            multiplier=BEAM_WIDTH)

    # Attention wrapper.
    self.attention_mechanism = seq2seq.BahdanauAttention(
        num_units=ENCODER_HIDDEN_SIZE,
        memory=encoder_outputs,
        memory_sequence_length=encoder_length)
    self.decoder_cell = seq2seq.AttentionWrapper(
        rnn.BasicLSTMCell(DECODER_HIDDEN_SIZE),
        attention_mechanism=self.attention_mechanism,
        attention_layer_size=ATTENTION_SIZE)
    self.decoder_initial_state = self.decoder_cell.zero_state(
        batch_size=batch_size,
        dtype=tf.float32).clone(cell_state=encoder_state)
    self.decoder_dense = tf.layers.Dense(
        self.vocab_size,
        dtype=tf.float32,
        use_bias=False,
        kernel_initializer=tf.truncated_normal_initializer(mean=0.0,
                                                           stddev=0.1))

    # During training, use TrainingHelper; otherwise use the greedy helper
    # or the beam-search decoder.
    if self.mode == "train":
        training_helper = seq2seq.TrainingHelper(
            inputs=self.answer_embeddings,
            sequence_length=self.answer_length)
        training_decoder = seq2seq.BasicDecoder(
            cell=self.decoder_cell,
            helper=training_helper,
            initial_state=self.decoder_initial_state,
            output_layer=self.decoder_dense)
        decoder_outputs, _, _ = seq2seq.dynamic_decode(
            decoder=training_decoder,
            impute_finished=True,
            maximum_iterations=self.max_decode_step)
        self.decoder_logits = tf.identity(decoder_outputs.rnn_output)
        self.loss = seq2seq.sequence_loss(logits=decoder_outputs.rnn_output,
                                          targets=self.answer_target,
                                          weights=self.sequence_mask)
        self.sample_ids = decoder_outputs.sample_id
        self.optimizer = tf.train.AdamOptimizer(LR_RATE)
        self.train_op = self.optimizer.minimize(self.loss)
        tf.summary.scalar('loss', self.loss)
        self.summary_op = tf.summary.merge_all()
    elif self.mode == "decode":
        start_tokens = tf.ones([self.batch_size], tf.int32) * self.go
        end_token = self.eos
        # With beam search, the values passed to the decoder keep the
        # original batch size; no BEAM_WIDTH-tiled tensor is needed here.
        # Either the beam-search or the greedy helper works; when only one
        # reply is returned they are equivalent.
        if USE_BEAMSEARCH:
            inference_decoder = seq2seq.BeamSearchDecoder(
                cell=self.decoder_cell,
                embedding=self.embeddings_matrix,
                start_tokens=start_tokens,
                end_token=end_token,
                initial_state=self.decoder_initial_state,
                beam_width=BEAM_WIDTH,
                output_layer=self.decoder_dense)
            # With beam search, the result holds (predicted_ids, beam_search_decoder_output):
            #   predicted_ids: [batch_size, decoder_targets_length, beam_size]
            #   beam_search_decoder_output: (scores, predicted_ids, parent_ids)
            decoder_outputs, _, _ = seq2seq.dynamic_decode(
                decoder=inference_decoder,
                maximum_iterations=self.max_decode_step)
            self.sample_ids = decoder_outputs.predicted_ids
            # Transpose so that each row is one hypothesis.
            self.sample_ids = tf.transpose(self.sample_ids, perm=[0, 2, 1])
        else:
            decoding_helper = seq2seq.GreedyEmbeddingHelper(
                start_tokens=start_tokens,
                end_token=end_token,
                embedding=self.embeddings_matrix)
            inference_decoder = seq2seq.BasicDecoder(
                cell=self.decoder_cell,
                helper=decoding_helper,
                initial_state=self.decoder_initial_state,
                output_layer=self.decoder_dense)
            # Without beam search, the result holds (rnn_output, sample_id):
            #   rnn_output: [batch_size, decoder_targets_length, vocab_size]
            #   sample_id: [batch_size, decoder_targets_length], tf.int32
            self.decoder_outputs_decode, self.final_state, _ = \
                seq2seq.dynamic_decode(decoder=inference_decoder,
                                       maximum_iterations=self.max_decode_step)
            self.sample_ids = self.decoder_outputs_decode.sample_id
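
# The zero_state(...).clone(cell_state=...) idiom above, in isolation: an
# AttentionWrapper state is a namedtuple wrapping the inner cell's state, so
# the encoder's final state is injected via clone(). Toy sizes assumed.
import tensorflow as tf
from tensorflow.contrib import seq2seq

memory_demo = tf.zeros([4, 7, 16])               # [batch, time, depth]
attention_demo = seq2seq.BahdanauAttention(num_units=16, memory=memory_demo)
cell_demo = seq2seq.AttentionWrapper(tf.nn.rnn_cell.GRUCell(16),
                                     attention_demo)
encoder_state_demo = tf.zeros([4, 16])           # stand-in final GRU state
init_demo = cell_demo.zero_state(4, tf.float32).clone(
    cell_state=encoder_state_demo)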
def sample(self, n, max_length=None, z=None, temperature=None,
           start_inputs=None, beam_width=None, end_token=None):
    """Overrides BaseLstmDecoder `sample` method to add optional beam search.

    Args:
      n: Scalar number of samples to return.
      max_length: (Optional) Scalar maximum sample length to return. Required
        if the data representation does not include end tokens.
      z: (Optional) Latent vectors to sample from. Required if the model is
        conditional. Sized `[n, z_size]`.
      temperature: (Optional) The softmax temperature to use when not doing
        beam search. Defaults to 1.0. Ignored when `beam_width` is provided.
      start_inputs: (Optional) Initial inputs to use for batch. Sized
        `[n, output_depth]`.
      beam_width: (Optional) Width of beam to use for beam search. Beam
        search is disabled if not provided.
      end_token: (Optional) Scalar token signaling the end of the sequence
        to use for early stopping.

    Returns:
      samples: Sampled sequences. Sized `[n, max_length, output_depth]`.

    Raises:
      ValueError: If `z` is provided and its first dimension does not
        equal `n`.
    """
    if beam_width is None:
        end_fn = (None if end_token is None else
                  lambda x: tf.equal(tf.argmax(x, axis=-1), end_token))
        return super(CategoricalLstmDecoder, self).sample(
            n, max_length, z, temperature, start_inputs, end_fn)

    # If `end_token` is not given, use an impossible value.
    end_token = self._output_depth if end_token is None else end_token
    if z is not None and z.shape[0].value != n:
        raise ValueError(
            '`z` must have a first dimension that equals `n` when given. '
            'Got: %d vs %d' % (z.shape[0].value, n))

    if temperature is not None:
        tf.logging.warning('`temperature` is ignored when using beam search.')

    # Use a dummy Z in the unconditional case.
    z = tf.zeros((n, 0), tf.float32) if z is None else z

    # If not given, start with a dummy `-1` token and replace it with zero
    # vectors in `embedding_fn`.
    start_tokens = (tf.argmax(start_inputs, axis=-1, output_type=tf.int32)
                    if start_inputs is not None else
                    -1 * tf.ones([n], dtype=tf.int32))

    initial_state = initial_cell_state_from_embedding(
        self._dec_cell, z, name='decoder/z_to_initial_state')
    beam_initial_state = seq2seq.tile_batch(initial_state,
                                            multiplier=beam_width)
    # Tile `z` across beams.
    beam_z = tf.tile(tf.expand_dims(z, 1), [1, beam_width, 1])

    def embedding_fn(tokens):
        # If tokens are the start_tokens (negative), replace them with zero
        # vectors.
        next_inputs = tf.cond(
            tf.less(tokens[0, 0], 0),
            lambda: tf.zeros([n, beam_width, self._output_depth]),
            lambda: tf.one_hot(tokens, self._output_depth))
        # Concatenate `z` to the next inputs.
        next_inputs = tf.concat([next_inputs, beam_z], axis=-1)
        return next_inputs

    decoder = seq2seq.BeamSearchDecoder(self._dec_cell,
                                        embedding_fn,
                                        start_tokens,
                                        end_token,
                                        beam_initial_state,
                                        beam_width,
                                        output_layer=self._output_layer,
                                        length_penalty_weight=0.0)
    final_output, _, _ = seq2seq.dynamic_decode(
        decoder,
        maximum_iterations=max_length,
        swap_memory=True,
        scope='decoder')
    return tf.one_hot(final_output.predicted_ids[:, :, 0], self._output_depth)
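
# Tiling a per-sample latent across beams, as done for `beam_z` above, in
# isolation (numpy stand-in; the sizes are invented):
import numpy as np

z_demo = np.array([[0.1, 0.2], [0.3, 0.4]])           # [n=2, z_size=2]
beam_z_demo = np.tile(z_demo[:, None, :], [1, 3, 1])  # [n, beam_width=3, z_size]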
def build_model(self):
    """Build the model."""
    print('Building model...')
    # 1. Define the model placeholders.
    self.encoder_inputs = tf.placeholder(tf.int32, [None, None],
                                         name='encoder_inputs')
    self.encoder_inputs_length = tf.placeholder(tf.int32, [None],
                                                name='encoder_inputs_length')
    self.batch_size = tf.placeholder(tf.int32, [], name='batch_size')
    self.keep_prob_dropout = tf.placeholder(tf.float32,
                                            name='keep_prob_dropout')
    self.decoder_targets = tf.placeholder(tf.int32, [None, None],
                                          name='decoder_targets')
    self.decoder_targets_length = tf.placeholder(
        tf.int32, [None], name='decoder_targets_length')
    # Take the maximum target length and build the sequence-length mask
    # from it:
    #   tf.sequence_mask([1, 3, 2], 5) ->
    #     [[ True False False False False]
    #      [ True  True  True False False]
    #      [ True  True False False False]]
    self.max_target_sequence_length = tf.reduce_max(
        self.decoder_targets_length, name='max_target_len')
    self.mask = tf.sequence_mask(self.decoder_targets_length,
                                 self.max_target_sequence_length,
                                 dtype=tf.float32,
                                 name='masks')

    # 2. Define the encoder.
    with tf.variable_scope('encoder'):
        # Create the LSTM cells: two layers plus dropout.
        encoder_cell = self.create_rnn_cell()
        # Build the embedding matrix, shared by encoder and decoder:
        #   embedding.shape = (vocab_size, embedding_size)
        #   encoder_inputs.shape = (batch_size, encoder_inputs_length)
        #   encoder_inputs_embedded.shape = (batch_size, encoder_inputs_length, embedding_size)
        embedding = tf.get_variable('embedding',
                                    [self.vocab_size, self.embedding_size])
        encoder_inputs_embedded = tf.nn.embedding_lookup(
            embedding, self.encoder_inputs)
        # Run dynamic_rnn to encode the inputs into hidden vectors:
        #   encoder_outputs, used for attention: batch_size * encoder_inputs_length * rnn_size
        #   encoder_state, the decoder's initial state: batch_size * rnn_size
        encoder_outputs, encoder_state = tf.nn.dynamic_rnn(
            encoder_cell,
            encoder_inputs_embedded,
            sequence_length=self.encoder_inputs_length,
            dtype=tf.float32)

    # 3. Define the decoder.
    with tf.variable_scope('decoder'):
        encoder_inputs_length = self.encoder_inputs_length
        if self.beam_search:
            # With beam search, the encoder outputs must be tiled beam_size
            # times via tile_batch.
            print("use beam search decoding...")
            encoder_outputs = seq2seq.tile_batch(encoder_outputs,
                                                 multiplier=self.beam_size)
            encoder_state = nest.map_structure(
                lambda s: seq2seq.tile_batch(s, self.beam_size),
                encoder_state)
            encoder_inputs_length = seq2seq.tile_batch(
                self.encoder_inputs_length, multiplier=self.beam_size)
        # Define the attention mechanism to use.
        attention_mechanism = seq2seq.BahdanauAttention(
            num_units=self.rnn_size,
            memory=encoder_outputs,
            memory_sequence_length=encoder_inputs_length)
        # attention_mechanism = seq2seq.LuongAttention(
        #     num_units=self.rnn_size, memory=encoder_outputs,
        #     memory_sequence_length=encoder_inputs_length)
        # Define the decoder LSTM cells and wrap them with attention.
        decoder_cell = self.create_rnn_cell()
        decoder_cell = seq2seq.AttentionWrapper(
            cell=decoder_cell,
            attention_mechanism=attention_mechanism,
            attention_layer_size=self.rnn_size,
            name='Attention_Wrapper')
        # With beam search, batch_size = self.batch_size * self.beam_size.
        batch_size = (self.batch_size if not self.beam_search
                      else self.batch_size * self.beam_size)
        # The decoder's initial state is the encoder's final hidden state.
        decoder_initial_state = decoder_cell.zero_state(
            batch_size=batch_size,
            dtype=tf.float32).clone(cell_state=encoder_state)
        output_layer = tf.layers.Dense(
            self.vocab_size,
            kernel_initializer=tf.truncated_normal_initializer(mean=0.0,
                                                               stddev=0.1))
        if self.mode == 'train':
            # The decoder inputs: prepend <go> to the targets, drop the
            # trailing <end>, then embed.
            # decoder_inputs_embedded: [batch_size, decoder_targets_length, embedding_size]
            ending = tf.strided_slice(self.decoder_targets, [0, 0],
                                      [self.batch_size, -1], [1, 1])
            decoder_inputs = tf.concat([
                tf.fill([self.batch_size, 1],
                        tf.cast(self.word2id[self.goToken], dtype=tf.int32)),
                ending
            ], 1)
            decoder_inputs_embedded = tf.nn.embedding_lookup(
                embedding, decoder_inputs)
            # For training, the usual combination is TrainingHelper +
            # BasicDecoder; a custom Helper class can be substituted if
            # other behavior is needed.
            training_helper = seq2seq.TrainingHelper(
                inputs=decoder_inputs_embedded,
                sequence_length=self.decoder_targets_length,
                time_major=False,
                name='training_helper')
            training_decoder = seq2seq.BasicDecoder(
                cell=decoder_cell,
                helper=training_helper,
                initial_state=decoder_initial_state,
                output_layer=output_layer)
            # dynamic_decode returns a namedtuple (rnn_output, sample_id):
            #   rnn_output: [batch_size, decoder_targets_length, vocab_size],
            #     per-step scores over the vocabulary, used for the loss
            #   sample_id: tf.int32, the decoded ids
            decoder_outputs, _, _ = seq2seq.dynamic_decode(
                decoder=training_decoder,
                impute_finished=True,
                maximum_iterations=self.max_target_sequence_length)
            # Compute the loss and gradients, then define the AdamOptimizer
            # and train_op.
            self.decoder_logits_train = tf.identity(
                decoder_outputs.rnn_output)
            self.decoder_predict_train = tf.argmax(
                self.decoder_logits_train, axis=-1,
                name='decoder_pred_train')
            # sequence_loss takes the mask defined above as weights.
            self.loss = seq2seq.sequence_loss(
                logits=self.decoder_logits_train,
                targets=self.decoder_targets,
                weights=self.mask)
            # Training summary for the current batch loss.
            tf.summary.scalar('loss', self.loss)
            self.summary_op = tf.summary.merge_all()
            optimizer = tf.train.AdamOptimizer(self.learning_rate)
            trainable_params = tf.trainable_variables()
            gradients = tf.gradients(self.loss, trainable_params)
            clip_gradients, _ = tf.clip_by_global_norm(
                gradients, self.max_gradient_norm)
            self.train_op = optimizer.apply_gradients(
                zip(clip_gradients, trainable_params))
        elif self.mode == 'predict':
            start_tokens = tf.ones([self.batch_size],
                                   tf.int32) * self.word2id[self.goToken]
            end_token = self.word2id[self.endToken]
            # The decoding combination depends on whether beam search is
            # used: BeamSearchDecoder already implements its own helper;
            # otherwise GreedyEmbeddingHelper + BasicDecoder performs greedy
            # decoding.
            if self.beam_search:
                inference_decoder = seq2seq.BeamSearchDecoder(
                    cell=decoder_cell,
                    embedding=embedding,
                    start_tokens=start_tokens,
                    end_token=end_token,
                    initial_state=decoder_initial_state,
                    beam_width=self.beam_size,
                    output_layer=output_layer)
            else:
                decoder_helper = seq2seq.GreedyEmbeddingHelper(
                    embedding=embedding,
                    start_tokens=start_tokens,
                    end_token=end_token)
                inference_decoder = seq2seq.BasicDecoder(
                    cell=decoder_cell,
                    helper=decoder_helper,
                    initial_state=decoder_initial_state,
                    output_layer=output_layer)
            decoder_outputs, _, _ = seq2seq.dynamic_decode(
                decoder=inference_decoder, maximum_iterations=10)
            # dynamic_decode returns a namedtuple. Without beam search it
            # holds (rnn_output, sample_id):
            #   rnn_output: [batch_size, decoder_targets_length, vocab_size]
            #   sample_id: [batch_size, decoder_targets_length], tf.int32
            # With beam search it holds
            # (predicted_ids, beam_search_decoder_output):
            #   predicted_ids: [batch_size, decoder_targets_length, beam_size]
            #   beam_search_decoder_output: BeamSearchDecoderOutput
            #     namedtuple (scores, predicted_ids, parent_ids)
            # So returning predicted_ids or sample_id gives the final result.
            if self.beam_search:
                self.decoder_predict_decoder = decoder_outputs.predicted_ids
            else:
                self.decoder_predict_decoder = tf.expand_dims(
                    decoder_outputs.sample_id, -1)

    # 4. Save the model.
    self.saver = tf.train.Saver(tf.global_variables())
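
# The <go>-prepend / last-column-drop shift above, in isolation (numpy
# stand-in; the go id 2, end id 3, and pad 0 are invented):
import numpy as np

targets_demo = np.array([[5, 6, 3], [7, 3, 0]])
ending_demo = targets_demo[:, :-1]               # what tf.strided_slice keeps
decoder_inputs_demo = np.concatenate(
    [np.full((2, 1), 2), ending_demo], axis=1)   # [[2,5,6],[2,7,3]]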
def call(self, inputs, training=None, mask=None):
    dec_emb_fn = lambda ids: self.embed(ids)
    if self.is_infer:
        enc_outputs, enc_state, enc_seq_len = inputs
        batch_size = tf.shape(enc_outputs)[0]
        helper = seq2seq.GreedyEmbeddingHelper(
            embedding=dec_emb_fn,
            start_tokens=tf.fill([batch_size], self.dec_start_id),
            end_token=self.dec_end_id)
    else:
        dec_inputs, dec_seq_len, enc_outputs, enc_state, enc_seq_len = inputs
        batch_size = tf.shape(enc_outputs)[0]
        dec_inputs = self.embed(dec_inputs)
        helper = seq2seq.TrainingHelper(inputs=dec_inputs,
                                        sequence_length=dec_seq_len)

    if self.is_infer and self.beam_size > 1:
        tiled_enc_outputs = seq2seq.tile_batch(enc_outputs,
                                               multiplier=self.beam_size)
        tiled_seq_len = seq2seq.tile_batch(enc_seq_len,
                                           multiplier=self.beam_size)
        attn_mech = self._build_attention(enc_outputs=tiled_enc_outputs,
                                          enc_seq_len=tiled_seq_len)
        dec_cell = seq2seq.AttentionWrapper(self.cell, attn_mech)
        tiled_enc_last_state = seq2seq.tile_batch(enc_state,
                                                  multiplier=self.beam_size)
        tiled_dec_init_state = dec_cell.zero_state(
            batch_size=batch_size * self.beam_size, dtype=tf.float32)
        if self.initial_decode_state:
            tiled_dec_init_state = tiled_dec_init_state.clone(
                cell_state=tiled_enc_last_state)
        dec = seq2seq.BeamSearchDecoder(
            cell=dec_cell,
            embedding=dec_emb_fn,
            start_tokens=tf.tile([self.dec_start_id], [batch_size]),
            end_token=self.dec_end_id,
            initial_state=tiled_dec_init_state,
            beam_width=self.beam_size,
            output_layer=tf.layers.Dense(self.vocab_size),
            length_penalty_weight=self.length_penalty)
    else:
        attn_mech = self._build_attention(enc_outputs=enc_outputs,
                                          enc_seq_len=enc_seq_len)
        dec_cell = seq2seq.AttentionWrapper(cell=self.cell,
                                            attention_mechanism=attn_mech)
        dec_init_state = dec_cell.zero_state(batch_size=batch_size,
                                             dtype=tf.float32)
        if self.initial_decode_state:
            dec_init_state = dec_init_state.clone(cell_state=enc_state)
        dec = seq2seq.BasicDecoder(
            cell=dec_cell,
            helper=helper,
            initial_state=dec_init_state,
            output_layer=tf.layers.Dense(self.vocab_size))

    if self.is_infer:
        # Note: this return assumes beam_size > 1 at inference; with the
        # greedy BasicDecoder the output carries sample_id rather than
        # predicted_ids, so this line would fail for beam_size == 1.
        dec_outputs, _, _ = seq2seq.dynamic_decode(
            decoder=dec,
            maximum_iterations=self.max_dec_len,
            swap_memory=self.swap_memory,
            output_time_major=self.time_major)
        return dec_outputs.predicted_ids[:, :, 0]
    else:
        dec_outputs, _, _ = seq2seq.dynamic_decode(
            decoder=dec,
            maximum_iterations=tf.reduce_max(dec_seq_len),
            swap_memory=self.swap_memory,
            output_time_major=self.time_major)
        return dec_outputs.rnn_output
def __call__(self, top_k_attributes, mean_image_features=None,
             mean_object_features=None, spatial_image_features=None,
             spatial_object_features=None, seq_inputs=None, lengths=None):
    assert (mean_image_features is not None
            or mean_object_features is not None
            or spatial_image_features is not None
            or spatial_object_features is not None)
    attribute_features = tf.nn.embedding_lookup(self.attribute_embeddings_map,
                                                top_k_attributes)
    mean_attribute_features = tf.reduce_mean(attribute_features, [1])
    use_beam_search = (seq_inputs is None or lengths is None)
    if mean_image_features is not None:
        batch_size = tf.shape(mean_image_features)[0]
        mean_image_features = tf.concat(
            [mean_image_features, mean_attribute_features], 1)
    elif mean_object_features is not None:
        batch_size = tf.shape(mean_object_features)[0]
        mean_object_features = tf.concat(
            [mean_object_features, attribute_features], 1)
    elif spatial_image_features is not None:
        batch_size = tf.shape(spatial_image_features)[0]
        spatial_image_features = collapse_dims(spatial_image_features, [1, 2])
        mean_image_features = tf.concat(
            [tf.reduce_mean(spatial_image_features, [1]),
             mean_attribute_features], 1)
        spatial_image_features = tf.concat(
            [spatial_image_features, attribute_features], 1)
    elif spatial_object_features is not None:
        batch_size = tf.shape(spatial_object_features)[0]
        spatial_object_features = collapse_dims(spatial_object_features,
                                                [2, 3])
        mean_object_features = tf.concat(
            [tf.reduce_mean(spatial_object_features, [2]),
             attribute_features], 1)
        spatial_object_features = tf.concat(
            [spatial_object_features,
             tf.expand_dims(attribute_features, 2)], 2)
    initial_state = self.image_caption_cell.zero_state(batch_size, tf.float32)
    if use_beam_search:
        if mean_image_features is not None:
            mean_image_features = seq2seq.tile_batch(
                mean_image_features, multiplier=self.beam_size)
            self.image_caption_cell.mean_image_features = mean_image_features
        if mean_object_features is not None:
            mean_object_features = seq2seq.tile_batch(
                mean_object_features, multiplier=self.beam_size)
            self.image_caption_cell.mean_object_features = mean_object_features
        if spatial_image_features is not None:
            spatial_image_features = seq2seq.tile_batch(
                spatial_image_features, multiplier=self.beam_size)
            self.image_caption_cell.spatial_image_features = spatial_image_features
        if spatial_object_features is not None:
            spatial_object_features = seq2seq.tile_batch(
                spatial_object_features, multiplier=self.beam_size)
            self.image_caption_cell.spatial_object_features = spatial_object_features
        initial_state = seq2seq.tile_batch(initial_state,
                                           multiplier=self.beam_size)
        decoder = seq2seq.BeamSearchDecoder(
            self.image_caption_cell,
            self.word_embeddings_map,
            tf.fill([batch_size], self.word_vocabulary.start_id),
            self.word_vocabulary.end_id,
            initial_state,
            self.beam_size,
            output_layer=self.word_logits_layer)
        outputs, state, lengths = seq2seq.dynamic_decode(
            decoder, maximum_iterations=self.maximum_iterations)
        ids = tf.transpose(outputs.predicted_ids, [0, 2, 1])
        sequence_length = tf.shape(ids)[2]
        flat_ids = tf.reshape(ids,
                              [batch_size * self.beam_size, sequence_length])
        seq_inputs = tf.concat(
            [tf.fill([batch_size * self.beam_size, 1],
                     self.word_vocabulary.start_id), flat_ids], 1)
    if mean_image_features is not None:
        self.image_caption_cell.mean_image_features = mean_image_features
    if mean_object_features is not None:
        self.image_caption_cell.mean_object_features = mean_object_features
    if spatial_image_features is not None:
        self.image_caption_cell.spatial_image_features = spatial_image_features
    if spatial_object_features is not None:
        self.image_caption_cell.spatial_object_features = spatial_object_features
    activations, _state = tf.nn.dynamic_rnn(
        self.image_caption_cell,
        tf.nn.embedding_lookup(self.word_embeddings_map, seq_inputs),
        sequence_length=tf.reshape(lengths, [-1]),
        initial_state=initial_state)
    logits = self.word_logits_layer(activations)
    if use_beam_search:
        length = tf.shape(logits)[1]
        logits = tf.reshape(
            logits, [batch_size, self.beam_size, length, self.vocab_size])
    return logits, tf.argmax(logits, axis=-1, output_type=tf.int32)
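
# Flattening beams into the batch axis, as done before re-scoring above, in
# isolation (numpy stand-in; sizes invented): [batch, beam, time] ->
# [batch * beam, time].
import numpy as np

ids_demo = np.arange(24).reshape(2, 3, 4)        # [batch=2, beam=3, time=4]
flat_ids_demo = ids_demo.reshape(2 * 3, 4)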