Code Example #1
def add_decoder_for_training(self):
    # Assumes TF 1.x with `import tensorflow as tf` and
    # `from tensorflow.python.layers import core as core_layers`.
    # Stack n_layers RNN cells, then wrap the stack with the project's
    # self-attention wrapper.
    self.decoder_cell = tf.nn.rnn_cell.MultiRNNCell(
        [self.single_cell() for _ in range(self.n_layers)])
    self.decoder_cell = SelfAttWrapper(
        self.decoder_cell,
        self.init_attention,
        self.init_memory,
        att_layer=core_layers.Dense(self.rnn_size, name='att_dense'),
        att_type=self.att_type)
    decoder_embedding = tf.get_variable(
        'word_embedding',
        [len(self.dp.X_w2id), self.decoder_embedding_dim], tf.float32,
        tf.random_uniform_initializer(-1.0, 1.0))
    # Scheduled sampling: with probability (1 - force_teaching_ratio)
    # the helper feeds back a sampled model prediction instead of the
    # ground-truth token at the next step.
    training_helper = tf.contrib.seq2seq.ScheduledEmbeddingTrainingHelper(
        inputs=tf.nn.embedding_lookup(decoder_embedding,
                                      self.processed_decoder_input()),
        sequence_length=self.X_seq_len,
        embedding=decoder_embedding,
        sampling_probability=1 - self.force_teaching_ratio,
        time_major=False)
    training_decoder = tf.contrib.seq2seq.BasicDecoder(
        cell=self.decoder_cell,
        helper=training_helper,
        initial_state=self.decoder_cell.zero_state(
            self.batch_size,
            tf.float32),  # .clone(cell_state=self.encoder_state)
        output_layer=core_layers.Dense(len(self.dp.X_w2id),
                                       name='output_dense'))
    training_decoder_output, training_final_state, _ = tf.contrib.seq2seq.dynamic_decode(
        decoder=training_decoder,
        impute_finished=True,
        maximum_iterations=tf.reduce_max(self.X_seq_len))
    self.training_logits = training_decoder_output.rnn_output
    # Keep the decoder's final state so prefix decoding can resume from it.
    self.init_prefix_state = training_final_state
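
The logits exposed above feed the training objective. As a minimal, standalone sketch (not from the original code), this is how a `[batch, time, vocab]` logits tensor is typically turned into a padding-masked cross-entropy loss with `tf.contrib.seq2seq.sequence_loss` in TF 1.x; the toy shapes and constants here are assumptions for illustration only.

import tensorflow as tf

# Standalone sketch of a masked sequence loss; the original project may
# compute its loss differently.
logits = tf.random_normal([2, 5, 10])        # [batch, time, vocab]
targets = tf.zeros([2, 5], dtype=tf.int32)   # assumed gold token ids
seq_len = tf.constant([5, 3])                # true length per example
masks = tf.sequence_mask(seq_len, 5, dtype=tf.float32)
loss = tf.contrib.seq2seq.sequence_loss(logits, targets, weights=masks)

with tf.Session() as sess:
    print(sess.run(loss))  # scalar, averaged over non-padded positions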
Code Example #2
def add_decoder_for_prefix_sample(self):
    self.decoder_cell = tf.nn.rnn_cell.MultiRNNCell(
        [self.single_cell() for _ in range(self.n_layers)])
    # _reuse=True shares the attention projection created in
    # add_decoder_for_training instead of creating a new layer.
    self.decoder_cell = SelfAttWrapper(self.decoder_cell,
                                       self.init_attention,
                                       self.init_memory,
                                       att_layer=core_layers.Dense(
                                           self.rnn_size,
                                           name='att_dense',
                                           _reuse=True),
                                       att_type=self.att_type)
    # Retrieves the embedding created in add_decoder_for_training;
    # calling get_variable without a shape only works when the
    # enclosing variable scope has reuse=True (see the sketch below).
    word_embedding = tf.get_variable('word_embedding')
    # MyHelper is the project's custom sampling helper.
    prefix_sample_helper = my_helper.MyHelper(
        inputs=self.processed_decoder_input(),
        sequence_length=self.X_seq_len,
        embedding=word_embedding,
        end_token=self._x_eos)
    sample_prefix_decoder = tf.contrib.seq2seq.BasicDecoder(
        cell=self.decoder_cell,
        helper=prefix_sample_helper,
        initial_state=self.decoder_cell.zero_state(
            self.batch_size,
            tf.float32),  # .clone(cell_state=self.encoder_state)
        output_layer=core_layers.Dense(len(self.dp.X_w2id),
                                       name='output_dense',
                                       _reuse=True))
    sample_decoder_prefix_output, self.sample_prefix_final_state, _ = tf.contrib.seq2seq.dynamic_decode(
        decoder=sample_prefix_decoder,
        impute_finished=False,
        maximum_iterations=self.max_infer_length)
    self.sample_prefix_output = sample_decoder_prefix_output.sample_id
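
The shapeless `tf.get_variable('word_embedding')` call above relies on TF 1.x variable reuse: without a shape, `get_variable` can only look up an existing variable, which requires the enclosing scope to be in reuse mode. A standalone sketch of the pattern (the scope name 'decoder' is illustrative, not from the source):

import tensorflow as tf

# First call creates the variable; the reusing scope retrieves it
# by name, so no shape is needed the second time.
with tf.variable_scope('decoder'):
    emb = tf.get_variable('word_embedding', [1000, 128], tf.float32)

with tf.variable_scope('decoder', reuse=True):
    same_emb = tf.get_variable('word_embedding')  # no shape needed

assert emb is same_emb  # both names resolve to the same variable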
Code Example #3
def add_decoder_for_prefix_inference(self):
    self.decoder_cell = tf.nn.rnn_cell.MultiRNNCell(
        [self.single_cell() for _ in range(self.n_layers)])
    # Tile the attention context and memory so that each of the
    # beam_width hypotheses sees the same encoder context.
    self.init_attention_tiled = tf.contrib.seq2seq.tile_batch(
        self.init_attention, self.beam_width)
    self.init_memory_tiled = tf.contrib.seq2seq.tile_batch(
        self.init_memory, self.beam_width)

    self.decoder_cell = SelfAttWrapper(
        self.decoder_cell,
        self.init_attention_tiled,
        self.init_memory_tiled,
        att_layer=core_layers.Dense(self.rnn_size, name='att_dense',
                                    _reuse=True),
        att_type=self.att_type)
    # Resume beam search from the state left by the training decoder.
    self.beam_init_state = tf.contrib.seq2seq.tile_batch(
        self.init_prefix_state, self.beam_width)
    # DiverseDecode.BeamSearchDecoder is the project's diverse beam
    # search variant; gamma controls the diversity penalty.
    my_decoder = DiverseDecode.BeamSearchDecoder(
        cell=self.decoder_cell,
        embedding=tf.get_variable('word_embedding'),
        start_tokens=tf.tile(tf.constant([self._x_go], dtype=tf.int32),
                             [self.batch_size]),
        end_token=self._x_eos,
        gamma=self.gamma,
        initial_state=self.beam_init_state,
        beam_width=self.beam_width,
        vocab_size=len(self.dp.X_w2id),
        output_layer=core_layers.Dense(len(self.dp.X_w2id),
                                       name='output_dense', _reuse=True),
        length_penalty_weight=self.beam_penalty)

    # Start decoding from the last prefix token instead of the GO
    # symbol: overwrite the decoder's private _start_inputs with the
    # embedded prefix token, tiled across beams.
    self.prefix_go = tf.placeholder(tf.int32, [None])
    prefix_go_beam = tf.tile(tf.expand_dims(self.prefix_go, 1),
                             [1, self.beam_width])
    prefix_emb = tf.nn.embedding_lookup(tf.get_variable('word_embedding'),
                                        prefix_go_beam)
    my_decoder._start_inputs = prefix_emb
    predicting_decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode(
        decoder=my_decoder,
        impute_finished=False,
        maximum_iterations=self.max_infer_length)
    self.prefix_infer_outputs = predicting_decoder_output.predicted_ids
    self.score = predicting_decoder_output.beam_search_decoder_output.scores
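
Both the attention inputs and the decoder state above are expanded with `tf.contrib.seq2seq.tile_batch`, which repeats each batch entry `beam_width` times along axis 0 so the flattened `[batch * beam_width, ...]` layout lines up with what beam search expects. A standalone sketch of the tiling (toy tensor assumed for illustration):

import tensorflow as tf

# tile_batch repeats each entry consecutively: [a, b] with
# multiplier=3 becomes [a, a, a, b, b, b].
memory = tf.constant([[1., 2.], [3., 4.]])  # [batch=2, d=2]
tiled = tf.contrib.seq2seq.tile_batch(memory, multiplier=3)

with tf.Session() as sess:
    print(sess.run(tiled))
    # [[1. 2.] [1. 2.] [1. 2.] [3. 4.] [3. 4.] [3. 4.]] -> shape [6, 2]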