def add_decoder_for_training(self):
    """Build the training-time decoder graph (teacher forcing with scheduled sampling).

    Side effects: creates the 'word_embedding', 'att_dense' and 'output_dense'
    variables (re-fetched by the other decoder graphs), and sets
    self.decoder_cell, self.training_logits and self.init_prefix_state.
    """
    # Stacked RNN cells wrapped with self-attention over the initial memory.
    stacked_cells = tf.nn.rnn_cell.MultiRNNCell(
        [self.single_cell() for _ in range(self.n_layers)])
    self.decoder_cell = SelfAttWrapper(
        stacked_cells,
        self.init_attention,
        self.init_memory,
        att_layer=core_layers.Dense(self.rnn_size, name='att_dense'),
        att_type=self.att_type)

    # Embedding table created here; sibling graphs look it up again by name.
    embedding_table = tf.get_variable(
        'word_embedding',
        [len(self.dp.X_w2id), self.decoder_embedding_dim],
        tf.float32,
        tf.random_uniform_initializer(-1.0, 1.0))

    # Scheduled sampling: with probability (1 - force_teaching_ratio) the
    # model's own sampled token is fed back instead of the ground-truth one.
    helper = tf.contrib.seq2seq.ScheduledEmbeddingTrainingHelper(
        inputs=tf.nn.embedding_lookup(
            embedding_table, self.processed_decoder_input()),
        sequence_length=self.X_seq_len,
        embedding=embedding_table,
        sampling_probability=1 - self.force_teaching_ratio,
        time_major=False)

    decoder = tf.contrib.seq2seq.BasicDecoder(
        cell=self.decoder_cell,
        helper=helper,
        initial_state=self.decoder_cell.zero_state(self.batch_size, tf.float32),
        output_layer=core_layers.Dense(
            len(self.dp.X_w2id), name='output_dense'))

    outputs, final_state, _ = tf.contrib.seq2seq.dynamic_decode(
        decoder=decoder,
        impute_finished=True,
        maximum_iterations=tf.reduce_max(self.X_seq_len))

    self.training_logits = outputs.rnn_output
    # Kept so the prefix-inference graph can resume decoding from this state.
    self.init_prefix_state = final_state
def add_decoder_for_prefix_sample(self):
    """Build the sampling decoder that continues decoding from a given prefix.

    Side effects: rebinds self.decoder_cell and sets
    self.sample_prefix_final_state and self.sample_prefix_output.
    """
    stacked_cells = tf.nn.rnn_cell.MultiRNNCell(
        [self.single_cell() for _ in range(self.n_layers)])
    self.decoder_cell = SelfAttWrapper(
        stacked_cells,
        self.init_attention,
        self.init_memory,
        # Reuse the attention projection created by the training graph.
        att_layer=core_layers.Dense(
            self.rnn_size, name='att_dense', _reuse=True),
        att_type=self.att_type)

    # Fetch the embedding table created by the training graph — presumably
    # the enclosing variable scope has reuse enabled; verify against caller.
    embedding_table = tf.get_variable('word_embedding')
    helper = my_helper.MyHelper(
        inputs=self.processed_decoder_input(),
        sequence_length=self.X_seq_len,
        embedding=embedding_table,
        end_token=self._x_eos)

    decoder = tf.contrib.seq2seq.BasicDecoder(
        cell=self.decoder_cell,
        helper=helper,
        initial_state=self.decoder_cell.zero_state(self.batch_size, tf.float32),
        output_layer=core_layers.Dense(
            len(self.dp.X_w2id), name='output_dense', _reuse=True))

    outputs, self.sample_prefix_final_state, _ = tf.contrib.seq2seq.dynamic_decode(
        decoder=decoder,
        impute_finished=False,
        maximum_iterations=self.max_infer_length)
    self.sample_prefix_output = outputs.sample_id
def add_decoder_for_prefix_inference(self):
    """Build the beam-search decoder graph that completes a given prefix.

    Resumes decoding from self.init_prefix_state (set by the training
    graph) and searches with a diversity-aware BeamSearchDecoder.

    NOTE(review): reuses variables ('att_dense', 'output_dense',
    'word_embedding') created by the training graph — presumably the
    enclosing variable scope has reuse enabled; confirm call order.

    Side effects: rebinds self.decoder_cell and sets self.prefix_go
    (placeholder), self.prefix_infer_outputs and self.score.
    """
    self.decoder_cell = tf.nn.rnn_cell.MultiRNNCell([self.single_cell() for _ in range(1 * self.n_layers)])
    # Tile attention/memory so every one of the beam_width hypotheses per
    # batch element sees the same attention inputs.
    self.init_attention_tiled = tf.contrib.seq2seq.tile_batch(self.init_attention, self.beam_width)
    self.init_memory_tiled = tf.contrib.seq2seq.tile_batch(self.init_memory, self.beam_width)
    self.decoder_cell = SelfAttWrapper(self.decoder_cell, self.init_attention_tiled, self.init_memory_tiled, att_layer = core_layers.Dense(self.rnn_size, name='att_dense', _reuse=True),att_type=self.att_type)
    # Resume from the prefix's final decoder state instead of a zero state.
    self.beam_init_state = tf.contrib.seq2seq.tile_batch(self.init_prefix_state, self.beam_width)
    # Custom beam search; gamma is forwarded to DiverseDecode — presumably a
    # diversity-penalty weight, TODO confirm in DiverseDecode.
    my_decoder = DiverseDecode.BeamSearchDecoder(
        cell = self.decoder_cell,
        embedding = tf.get_variable('word_embedding'),
        start_tokens = tf.tile(tf.constant([self._x_go], dtype=tf.int32), [self.batch_size]),
        end_token = self._x_eos,
        gamma = self.gamma,
        initial_state = self.beam_init_state,
        beam_width = self.beam_width,
        vocab_size = len(self.dp.X_w2id),
        output_layer = core_layers.Dense(len(self.dp.X_w2id), name='output_dense', _reuse=True),
        length_penalty_weight = self.beam_penalty)
    # One token id per batch element — presumably the last prefix token;
    # verify against the feed_dict in the caller.
    self.prefix_go = tf.placeholder(tf.int32, [None])
    prefix_go_beam = tf.tile(tf.expand_dims(self.prefix_go, 1), [1, self.beam_width])
    prefix_emb = tf.nn.embedding_lookup(tf.get_variable('word_embedding'), prefix_go_beam)
    # HACK: overwrite the decoder's private _start_inputs so the first step
    # is fed the prefix token's embedding rather than the GO embedding built
    # from start_tokens above.
    my_decoder._start_inputs = prefix_emb
    predicting_decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode(
        decoder = my_decoder,
        impute_finished = False,
        maximum_iterations = self.max_infer_length)
    self.prefix_infer_outputs = predicting_decoder_output.predicted_ids
    self.score = predicting_decoder_output.beam_search_decoder_output.scores