def test_decode_train(self):
    """Checks AttentionRNNDecoder outputs under teacher-forcing decoding.

    The decoder attends over `self._encoder_output` with random memory
    lengths and is driven by a TrainingHelper over full-length inputs, so
    every output must span exactly `self._max_time` steps.
    """
    # One random memory length in [1, max_time] per batch entry.
    memory_lengths = tf.constant(
        np.random.randint(self._max_time, size=[self._batch_size]) + 1)

    # probability_fn is left at its default: sparsemax on TF-CPU appears
    # to require `memory_sequence_length` to equal max_time.
    attention_kwargs = {"num_units": self._attention_dim}
    decoder = AttentionRNNDecoder(
        memory=self._encoder_output,
        memory_sequence_length=memory_lengths,
        vocab_size=self._vocab_size,
        hparams={"attention": {"kwargs": attention_kwargs}})

    full_lengths = [self._max_time] * self._batch_size
    train_cfg = decoder.hparams.helper_train
    helper = get_helper(
        train_cfg.type,
        inputs=self._inputs,
        sequence_length=full_lengths,
        **train_cfg.kwargs.todict())
    outputs, final_state, sequence_lengths = decoder(helper=helper)

    # 4 + 1 trainable variables: cell kernel/bias, output-layer
    # weight/bias, plus the attention memory layer. LuongAttention only
    # transforms the memory, so num_units must match the query depth.
    self.assertEqual(len(decoder.trainable_variables), 5)

    cell_dim = decoder.hparams.rnn_cell.kwargs.num_units
    with self.test_session() as sess:
        sess.run(tf.global_variables_initializer())
        fetched = sess.run(
            [outputs, final_state, sequence_lengths],
            feed_dict={context.global_mode(): tf.estimator.ModeKeys.TRAIN})
        outputs_, final_state_, sequence_lengths_ = fetched

        self.assertIsInstance(outputs_, AttentionRNNDecoderOutput)
        expected_logits_shape = (
            self._batch_size, self._max_time, self._vocab_size)
        self.assertEqual(outputs_.logits.shape, expected_logits_shape)
        self.assertEqual(outputs_.sample_id.shape,
                         (self._batch_size, self._max_time))
        self.assertEqual(final_state_.cell_state[0].shape,
                         (self._batch_size, cell_dim))
        np.testing.assert_array_equal(sequence_lengths_, full_lengths)
def test_decode_infer(self):
    """Tests decoding in inference mode.

    Builds the configured inference helper (``hparams.helper_infer``) and
    checks output shapes against the longest decoded sequence.
    """
    seq_length = np.random.randint(
        self._max_time, size=[self._batch_size]) + 1
    encoder_values_length = tf.constant(seq_length)
    hparams = {
        "attention": {
            "kwargs": {
                "num_units": 256,
            }
        }
    }
    decoder = AttentionRNNDecoder(
        vocab_size=self._vocab_size,
        memory=self._encoder_output,
        memory_sequence_length=encoder_values_length,
        hparams=hparams)

    # Both the helper type AND its extra kwargs must come from the *infer*
    # helper config; previously the train helper's kwargs were used here.
    infer_cfg = decoder.hparams.helper_infer
    helper_infer = get_helper(
        infer_cfg.type,
        embedding=self._embedding,
        start_tokens=[1] * self._batch_size,
        end_token=2,
        **infer_cfg.kwargs.todict())

    outputs, final_state, sequence_lengths = decoder(helper=helper_infer)

    # 4+1 trainable variables: cell-kernel, cell-bias,
    # fc-weight, fc-bias, and
    # memory_layer: For LuongAttention, we only transform the memory layer;
    # thus num_units *must* match the expected query depth.
    self.assertEqual(len(decoder.trainable_variables), 5)

    cell_dim = decoder.hparams.rnn_cell.kwargs.num_units
    with self.test_session() as sess:
        sess.run(tf.global_variables_initializer())
        outputs_, final_state_, sequence_lengths_ = sess.run(
            [outputs, final_state, sequence_lengths],
            feed_dict={
                context.global_mode(): tf.estimator.ModeKeys.PREDICT
            })
        self.assertIsInstance(outputs_, AttentionRNNDecoderOutput)
        # Inference stops at end_token, so the time dimension is the
        # longest decoded sequence rather than max_time.
        max_length = max(sequence_lengths_)
        self.assertEqual(
            outputs_.logits.shape,
            (self._batch_size, max_length, self._vocab_size))
        self.assertEqual(
            outputs_.sample_id.shape, (self._batch_size, max_length))
        self.assertEqual(final_state_.cell_state[0].shape,
                         (self._batch_size, cell_dim))
def test_beam_search_cell(self):
    """Checks that the beam-search cell shares variables with the decoder.

    Exercises :meth:`texar.tf.modules.AttentionRNNDecoder._get_beam_search_cell`
    after the decoder has been built once in teacher-forcing mode.
    """
    memory_lengths = tf.constant(
        np.random.randint(self._max_time, size=[self._batch_size]) + 1)
    attention_hparams = {
        "attention": {
            "kwargs": {
                "num_units": self._attention_dim,
                "probability_fn": "sparsemax",
            }
        }
    }
    decoder = AttentionRNNDecoder(
        memory=self._encoder_output,
        memory_sequence_length=memory_lengths,
        vocab_size=self._vocab_size,
        hparams=attention_hparams)

    train_cfg = decoder.hparams.helper_train
    helper_train = get_helper(
        train_cfg.type,
        inputs=self._inputs,
        sequence_length=[self._max_time] * self._batch_size,
        **train_cfg.kwargs.todict())
    # Build once so the decoder's variables exist; outputs are not needed.
    _, _, _ = decoder(helper=helper_train)

    # 4 + 1 trainable variables: cell kernel/bias, output-layer
    # weight/bias, plus the attention memory layer. LuongAttention only
    # transforms the memory, so num_units must match the query depth.
    self.assertEqual(len(decoder.trainable_variables), 5)

    beam_width = 3
    tiled_batch = self._batch_size * beam_width
    beam_cell = decoder._get_beam_search_cell(beam_width)
    step_input = tf.random_uniform([tiled_batch, self._emb_dim])
    step_state = beam_cell.zero_state(tiled_batch, tf.float32)
    _ = beam_cell(step_input, step_state)

    # Every variable of the beam cell must be one the decoder already owns.
    decoder_vars = decoder.trainable_variables
    for var in beam_cell.trainable_variables:
        self.assertIn(var, decoder_vars)
def test_decode_train(self):
    """Checks BasicRNNDecoder across the ways of requesting train decoding.

    Covers an explicit TrainingHelper, the default "train_greedy" strategy,
    the hparams-configured helper (``decoding_strategy=None``), and finally
    the hparams-configured helper in EVAL mode.
    """
    output_layer = tf.layers.Dense(self._vocab_size)
    decoder = BasicRNNDecoder(vocab_size=self._vocab_size,
                              output_layer=output_layer)
    full_lengths = [self._max_time] * self._batch_size

    train_cfg = decoder.hparams.helper_train
    helper_train = get_helper(
        train_cfg.type,
        inputs=self._inputs,
        sequence_length=full_lengths,
        **train_cfg.kwargs.todict())

    # Three equivalent teacher-forcing invocations, checked in order.
    train_calls = [
        dict(helper=helper_train),
        dict(inputs=self._inputs, sequence_length=full_lengths),
        dict(decoding_strategy=None, inputs=self._inputs,
             sequence_length=full_lengths),
    ]
    for call_kwargs in train_calls:
        outputs, final_state, sequence_lengths = decoder(**call_kwargs)
        self._test_outputs(decoder, outputs, final_state, sequence_lengths)

    # Non-train mode selects hparams.helper_infer instead.
    outputs, final_state, sequence_lengths = decoder(
        decoding_strategy=None,
        embedding=self._embedding,
        start_tokens=[1] * self._batch_size,
        end_token=2,
        mode=tf.estimator.ModeKeys.EVAL)
    self._test_outputs(decoder, outputs, final_state, sequence_lengths,
                       test_mode=True)
def test_decode_infer(self):
    """Tests decoding in inference mode.

    Builds the configured inference helper (``hparams.helper_infer``) and
    checks output shapes against the longest decoded sequence.
    """
    output_layer = tf.layers.Dense(self._vocab_size)
    decoder = BasicRNNDecoder(vocab_size=self._vocab_size,
                              output_layer=output_layer)

    # Both the helper type AND its extra kwargs must come from the *infer*
    # helper config; previously the train helper's kwargs were used here.
    infer_cfg = decoder.hparams.helper_infer
    helper_infer = get_helper(
        infer_cfg.type,
        embedding=self._embedding,
        start_tokens=[self._vocab_size - 2] * self._batch_size,
        end_token=self._vocab_size - 1,
        **infer_cfg.kwargs.todict())

    outputs, final_state, sequence_lengths = decoder(helper=helper_infer)

    # 4 trainable variables: cell-kernel, cell-bias, fc-layer-weights,
    # fc-layer-bias. The embedding is supplied externally (self._embedding),
    # so it is not owned by the decoder.
    self.assertEqual(len(decoder.trainable_variables), 4)

    cell_dim = decoder.hparams.rnn_cell.kwargs.num_units
    with self.test_session() as sess:
        sess.run(tf.global_variables_initializer())
        outputs_, final_state_, sequence_lengths_ = sess.run(
            [outputs, final_state, sequence_lengths],
            feed_dict={
                context.global_mode(): tf.estimator.ModeKeys.PREDICT
            })
        self.assertIsInstance(outputs_, BasicRNNDecoderOutput)
        # Inference stops at end_token, so the time dimension is the
        # longest decoded sequence rather than max_time.
        max_length = max(sequence_lengths_)
        self.assertEqual(
            outputs_.logits.shape,
            (self._batch_size, max_length, self._vocab_size))
        self.assertEqual(
            outputs_.sample_id.shape, (self._batch_size, max_length))
        self.assertEqual(final_state_[0].shape,
                         (self._batch_size, cell_dim))
def test_decode_train_with_tf(self):
    """Checks BasicRNNDecoder against TF's built-in seq2seq BasicDecoder.

    Both decoders share the same cell and output layer and receive the same
    randomly generated token ids and embedding table, so logits, sample ids,
    final LSTM state and sequence lengths must match exactly.
    """
    ids_ph = tf.placeholder(
        tf.int32, [self._batch_size, self._max_time], name="inputs")
    emb_ph = tf.placeholder(
        tf.float32, [self._vocab_size, self._emb_dim], name="emb")
    embedded = tf.nn.embedding_lookup(emb_ph, ids_ph)
    full_lengths = [self._max_time] * self._batch_size

    output_layer = tf.layers.Dense(self._vocab_size)
    decoder = BasicRNNDecoder(vocab_size=self._vocab_size,
                              output_layer=output_layer)
    train_cfg = decoder.hparams.helper_train
    helper_train = get_helper(
        train_cfg.type,
        inputs=embedded,
        sequence_length=full_lengths,
        **train_cfg.kwargs.todict())
    outputs, final_state, sequence_lengths = decoder(helper=helper_train)

    # Reference decoder assembled from TF primitives over the same cell.
    tf_helper = tf.contrib.seq2seq.TrainingHelper(embedded, full_lengths)
    tf_decoder = tf.contrib.seq2seq.BasicDecoder(
        decoder.cell,
        tf_helper,
        decoder.cell.zero_state(self._batch_size, tf.float32),
        output_layer=output_layer)
    tf_outputs, tf_final_state, tf_sequence_lengths = (
        tf.contrib.seq2seq.dynamic_decode(tf_decoder))

    cell_dim = decoder.hparams.rnn_cell.kwargs.num_units
    with self.test_session() as sess:
        sess.run(tf.global_variables_initializer())
        token_ids = np.random.randint(
            self._vocab_size,
            size=(self._batch_size, self._max_time),
            dtype=np.int32)
        emb_table = np.random.randn(self._vocab_size, self._emb_dim)
        feeds = {
            context.global_mode(): tf.estimator.ModeKeys.TRAIN,
            ids_ph: token_ids,
            emb_ph: emb_table,
        }

        outputs_, final_state_, sequence_lengths_ = sess.run(
            [outputs, final_state, sequence_lengths], feed_dict=feeds)
        self.assertEqual(final_state_[0].shape,
                         (self._batch_size, cell_dim))

        tf_outputs_, tf_final_state_, tf_sequence_lengths_ = sess.run(
            [tf_outputs, tf_final_state, tf_sequence_lengths],
            feed_dict=feeds)

        np.testing.assert_array_equal(outputs_.logits,
                                      tf_outputs_.rnn_output)
        np.testing.assert_array_equal(outputs_.sample_id,
                                      tf_outputs_.sample_id)
        np.testing.assert_array_equal(final_state_.c, tf_final_state_.c)
        np.testing.assert_array_equal(final_state_.h, tf_final_state_.h)
        np.testing.assert_array_equal(sequence_lengths_,
                                      tf_sequence_lengths_)
def _build(self, decoding_strategy="train_greedy", initial_state=None,
           inputs=None, memory=None, sequence_length=None, embedding=None,
           start_tokens=None, end_token=None, softmax_temperature=None,
           max_decoding_length=None, impute_finished=False,
           output_time_major=False, input_time_major=False, helper=None,
           mode=None, **kwargs):
    """Performs attention-based decoding.

    First passes :attr:`memory` to ``initialize_memory`` of every attention
    mechanism held by the wrapped cell, then decodes exactly like the
    shared RNN decoder ``_build``:

    1. Choose a helper: an explicit :attr:`helper` wins; otherwise
       :attr:`decoding_strategy` ("train_greedy", "infer_greedy" or
       "infer_sample") selects one; if that is also `None`, the helper
       configured in ``hparams["helper_train"]`` / ``hparams["helper_infer"]``
       is built according to :attr:`mode`.
    2. Use :attr:`initial_state`, or the cell's zero state if `None`.
    3. Use :attr:`max_decoding_length`, or the train/infer hparam chosen at
       graph run time (falling back to ``utils.MAX_SEQ_LENGTH``).
    4. Run ``dynamic_decode`` and, on the first call, register trainable
       variables of the cell, output layer and beam-search cell.

    Args:
        memory (optional): Forwarded to each attention mechanism's
            ``initialize_memory``. NOTE(review): it defaults to `None` and is
            forwarded even then; the tests in this file pass memory to the
            constructor and call the decoder without it, so
            ``initialize_memory`` presumably tolerates `None` — confirm.
        Other arguments: same meaning as in the shared RNN decoder
            ``_build``.

    Returns:
        `(outputs, final_state, sequence_lengths)` as returned by
        ``dynamic_decode``.

    Raises:
        ValueError: If :attr:`decoding_strategy` is not one of the three
            supported strings (and :attr:`helper` is `None`).
    """
    # Memory
    for _mechanism in self._cell._attention_mechanisms:
        _mechanism.initialize_memory(memory)

    # Helper: explicit `helper` > `decoding_strategy` > hparams config.
    if helper is not None:
        pass
    elif decoding_strategy is not None:
        if decoding_strategy == "train_greedy":
            helper = rnn_decoder_helpers._get_training_helper(
                inputs, sequence_length, embedding, input_time_major)
        elif decoding_strategy == "infer_greedy":
            helper = tx_helper.GreedyEmbeddingHelper(
                embedding, start_tokens, end_token)
        elif decoding_strategy == "infer_sample":
            helper = tx_helper.SampleEmbeddingHelper(
                embedding, start_tokens, end_token, softmax_temperature)
        else:
            raise ValueError(
                "Unknown decoding strategy: {}".format(decoding_strategy))
    else:
        if is_train_mode_py(mode):
            kwargs_ = copy.copy(self._hparams.helper_train.kwargs.todict())
            helper_type = self._hparams.helper_train.type
        else:
            kwargs_ = copy.copy(self._hparams.helper_infer.kwargs.todict())
            helper_type = self._hparams.helper_infer.type
        # Explicit arguments (even when None) override same-named hparams
        # kwargs; `**kwargs` then overrides both.
        kwargs_.update({
            "inputs": inputs,
            "sequence_length": sequence_length,
            "time_major": input_time_major,
            "embedding": embedding,
            "start_tokens": start_tokens,
            "end_token": end_token,
            "softmax_temperature": softmax_temperature})
        kwargs_.update(kwargs)
        helper = rnn_decoder_helpers.get_helper(helper_type, **kwargs_)
    self._helper = helper

    # Initial state
    # NOTE(review): evaluated after `self._helper` is set; `self.batch_size`
    # presumably derives from the helper, so keep this ordering.
    if initial_state is not None:
        self._initial_state = initial_state
    else:
        self._initial_state = self.zero_state(
            batch_size=self.batch_size, dtype=tf.float32)

    # Maximum decoding length: an explicit value wins; otherwise choose the
    # train/infer hparam at graph run time via `tf.cond`, with
    # `utils.MAX_SEQ_LENGTH` as the fallback for a None hparam.
    max_l = max_decoding_length
    if max_l is None:
        max_l_train = self._hparams.max_decoding_length_train
        if max_l_train is None:
            max_l_train = utils.MAX_SEQ_LENGTH
        max_l_infer = self._hparams.max_decoding_length_infer
        if max_l_infer is None:
            max_l_infer = utils.MAX_SEQ_LENGTH
        max_l = tf.cond(is_train_mode(mode),
                        lambda: max_l_train, lambda: max_l_infer)
    self.max_decoding_length = max_l
    # Decode
    outputs, final_state, sequence_lengths = dynamic_decode(
        decoder=self, impute_finished=impute_finished,
        maximum_iterations=max_l, output_time_major=output_time_major)

    # Variables only exist after dynamic_decode has run, so registration
    # happens here, once.
    if not self._built:
        self._add_internal_trainable_variables()
        # Add trainable variables of `self._cell` which may be
        # constructed externally.
        self._add_trainable_variable(
            layers.get_rnn_cell_trainable_variables(self._cell))
        if isinstance(self._output_layer, tf.layers.Layer):
            self._add_trainable_variable(
                self._output_layer.trainable_variables)
        # Add trainable variables of `self._beam_search_cell` which
        # may already be constructed and used.
        if self._beam_search_cell is not None:
            self._add_trainable_variable(
                self._beam_search_cell.trainable_variables)

        self._built = True

    return outputs, final_state, sequence_lengths
def _build(self, decoding_strategy="train_greedy", initial_state=None,
           inputs=None, sequence_length=None, embedding=None,
           start_tokens=None, end_token=None, softmax_temperature=None,
           max_decoding_length=None, impute_finished=False,
           output_time_major=False, input_time_major=False, helper=None,
           mode=None, **kwargs):
    """Performs decoding. This is a shared interface for both
    :class:`~texar.tf.modules.BasicRNNDecoder` and
    :class:`~texar.tf.modules.AttentionRNNDecoder`.

    The function provides **3 ways** to specify the
    decoding method, with varying flexibility:

    1. The :attr:`decoding_strategy` argument: A string taking value of:

        - **"train_greedy"**: decoding in teacher-forcing fashion \
          (i.e., feeding \
          `ground truth` to decode the next step), and each sample is \
          obtained by taking the `argmax` of the RNN output logits. \
          Arguments :attr:`(inputs, sequence_length, input_time_major)` \
          are required for this strategy, and argument :attr:`embedding` \
          is optional.
        - **"infer_greedy"**: decoding in inference fashion (i.e., feeding \
          the `generated` sample to decode the next step), and each sample \
          is obtained by taking the `argmax` of the RNN output logits. \
          Arguments :attr:`(embedding, start_tokens, end_token)` are \
          required for this strategy, and argument \
          :attr:`max_decoding_length` is optional.
        - **"infer_sample"**: decoding in inference fashion, and each
          sample is obtained by `random sampling` from the RNN output
          distribution. Arguments \
          :attr:`(embedding, start_tokens, end_token)` are \
          required for this strategy, and arguments \
          :attr:`max_decoding_length` and :attr:`softmax_temperature` \
          are optional.

      This argument is used only when argument :attr:`helper` is `None`.

      Example:

        .. code-block:: python

            embedder = WordEmbedder(vocab_size=data.vocab.size)
            decoder = BasicRNNDecoder(vocab_size=data.vocab.size)

            # Teacher-forcing decoding
            outputs_1, _, _ = decoder(
                decoding_strategy='train_greedy',
                inputs=embedder(data_batch['text_ids']),
                sequence_length=data_batch['length']-1)

            # Random sample decoding. Gets 100 sequence samples
            outputs_2, _, sequence_length = decoder(
                decoding_strategy='infer_sample',
                start_tokens=[data.vocab.bos_token_id]*100,
                end_token=data.vocab.eos_token_id,
                embedding=embedder,
                max_decoding_length=60)

    2. The :attr:`helper` argument: An instance of subclass of \
       :class:`texar.tf.modules.Helper`. This
       provides a superset of the decoding strategies above, for example:

        - :class:`~texar.tf.modules.TrainingHelper` corresponding to the \
          "train_greedy" strategy.
        - :class:`~texar.tf.modules.GreedyEmbeddingHelper` and \
          :class:`~texar.tf.modules.SampleEmbeddingHelper` corresponding to \
          the "infer_greedy" and "infer_sample", respectively.
        - :class:`~texar.tf.modules.TopKSampleEmbeddingHelper` for Top-K \
          sample decoding.
        - :class:`ScheduledEmbeddingTrainingHelper` and \
          :class:`ScheduledOutputTrainingHelper` for scheduled \
          sampling.
        - :class:`~texar.tf.modules.SoftmaxEmbeddingHelper` and \
          :class:`~texar.tf.modules.GumbelSoftmaxEmbeddingHelper` for \
          soft decoding and gradient backpropagation.

      Helpers give the maximal flexibility of configuring the decoding \
      strategy.

      Example:

        .. code-block:: python

            embedder = WordEmbedder(vocab_size=data.vocab.size)
            decoder = BasicRNNDecoder(vocab_size=data.vocab.size)

            # Teacher-forcing decoding, same as above with
            # `decoding_strategy='train_greedy'`
            helper_1 = tx.modules.TrainingHelper(
                inputs=embedder(data_batch['text_ids']),
                sequence_length=data_batch['length']-1)
            outputs_1, _, _ = decoder(helper=helper_1)

            # Gumbel-softmax decoding
            helper_2 = GumbelSoftmaxEmbeddingHelper(
                embedding=embedder,
                start_tokens=[data.vocab.bos_token_id]*100,
                end_token=data.vocab.eos_token_id,
                tau=0.1)
            outputs_2, _, sequence_length = decoder(
                max_decoding_length=60, helper=helper_2)

    3. :attr:`hparams["helper_train"]` and :attr:`hparams["helper_infer"]`: \
       Specifying the helper through hyperparameters. Train and infer \
       strategy is toggled based on :attr:`mode`. Appropriate arguments \
       (e.g., :attr:`inputs`, :attr:`start_tokens`, etc) are selected to \
       construct the helper. Additional arguments for helper constructor \
       can be provided either through :attr:`**kwargs`, or through \
       :attr:`hparams["helper_train/infer"]["kwargs"]`. The explicit \
       arguments above (even when `None`) override same-named entries in \
       the hparams kwargs, and :attr:`**kwargs` overrides both.

       This method is used only when both :attr:`decoding_strategy` and \
       :attr:`helper` are `None`.

       Example:

         .. code-block:: python

             h = {
                 "helper_infer": {
                     "type": "GumbelSoftmaxEmbeddingHelper",
                     "kwargs": { "tau": 0.1 }
                 }
             }
             embedder = WordEmbedder(vocab_size=data.vocab.size)
             decoder = BasicRNNDecoder(vocab_size=data.vocab.size, hparams=h)

             # Gumbel-softmax decoding
             output, _, _ = decoder(
                 decoding_strategy=None, # Set to None explicitly
                 embedding=embedder,
                 start_tokens=[data.vocab.bos_token_id]*100,
                 end_token=data.vocab.eos_token_id,
                 max_decoding_length=60,
                 mode=tf.estimator.ModeKeys.PREDICT)
                     # PREDICT mode also disables dropout

    Args:
        decoding_strategy (str): A string specifying the decoding
            strategy. Different arguments are required based on the
            strategy.
            Ignored if :attr:`helper` is given.
        initial_state (optional): Initial state of decoding.
            If `None` (default), zero state is used.

        inputs (optional): Input tensors for teacher forcing decoding.
            Used when `decoding_strategy` is set to "train_greedy", or
            when `hparams`-configured helper is used.

            - If :attr:`embedding` is `None`, `inputs` is directly \
            fed to the decoder. E.g., in `"train_greedy"` strategy, \
            `inputs` must be a 3D Tensor of shape \
            `[batch_size, max_time, emb_dim]` (or \
            `[max_time, batch_size, emb_dim]` if `input_time_major`==True).
            - If `embedding` is given, `inputs` is used as index \
            to look up embeddings and feed in the decoder. \
            E.g., if `embedding` is an instance of \
            :class:`~texar.tf.modules.WordEmbedder`, \
            then :attr:`inputs` is usually a 2D int Tensor \
            `[batch_size, max_time]` (or \
            `[max_time, batch_size]` if `input_time_major`==True) \
            containing the token indexes.
        sequence_length (optional): A 1D int Tensor containing the
            sequence length of :attr:`inputs`.
            Used when `decoding_strategy="train_greedy"` or
            `hparams`-configured helper is used.
        embedding (optional): Embedding used when:

            - "infer_greedy" or "infer_sample" `decoding_strategy` is \
            used. This can be a callable or the `params` argument for \
            :tf_main:`embedding_lookup <nn/embedding_lookup>`. \
            If a callable, it can take a vector tensor of token `ids`, \
            or take two arguments (`ids`, `times`), where `ids` \
            is a vector tensor of token ids, and `times` is a vector tensor \
            of time steps (i.e., position ids). The latter case can be used \
            when :attr:`embedding` is a combination of word embedding and \
            position embedding. `embedding` is required in this case.
            - "train_greedy" `decoding_strategy` is used. \
            This can be a callable or the `params` argument for \
            :tf_main:`embedding_lookup <nn/embedding_lookup>`. \
            If a callable, it can take :attr:`inputs` and returns \
            the input embedding. `embedding` is optional in this case.
        start_tokens (optional): An int Tensor of shape `[batch_size]`,
            the start tokens. Used when `decoding_strategy="infer_greedy"`
            or `"infer_sample"`, or when the helper specified in `hparams`
            is used.

            Example:

                .. code-block:: python

                    data = tx.data.MonoTextData(hparams)
                    iterator = DataIterator(data)
                    batch = iterator.get_next()

                    bos_token_id = data.vocab.bos_token_id
                    start_tokens=tf.ones_like(batch['length'])*bos_token_id

        end_token (optional): An int 0D Tensor, the token that marks end
            of decoding.
            Used when `decoding_strategy="infer_greedy"` or
            `"infer_sample"`, or when the helper specified in `hparams`
            is used.
        softmax_temperature (optional): A float 0D Tensor, value to divide
            the logits by before computing the softmax. Larger values
            (above 1.0) result in more random samples. Must be > 0. If
            `None`, 1.0 is used.
            Used when `decoding_strategy="infer_sample"`.
        max_decoding_length: An int scalar Tensor indicating the maximum
            allowed number of decoding steps. If `None` (default), either
            `hparams["max_decoding_length_train"]` or
            `hparams["max_decoding_length_infer"]` is used according to
            :attr:`mode`; a `None` hparam falls back to
            `utils.MAX_SEQ_LENGTH`.
        impute_finished (bool): If `True`, then states for batch
            entries which are marked as finished get copied through and
            the corresponding outputs get zeroed out. This causes some
            slowdown at each time step, but ensures that the final state
            and outputs have the correct values and that backprop ignores
            time steps that were marked as finished.
        output_time_major (bool): If `True`, outputs are returned as
            time major tensors. If `False` (default), outputs are returned
            as batch major tensors.
        input_time_major (optional): Whether the :attr:`inputs` tensor is
            time major.
            Used when `decoding_strategy="train_greedy"` or
            `hparams`-configured helper is used.
        helper (optional): An instance of
            :class:`texar.tf.modules.Helper`
            that defines the decoding strategy. If given,
            `decoding_strategy`
            and helper configs in :attr:`hparams` are ignored.
        mode (str, optional): A string taking value in
            :tf_main:`tf.estimator.ModeKeys <estimator/ModeKeys>`. If
            `TRAIN`, training related hyperparameters are used (e.g.,
            `hparams['max_decoding_length_train']`), otherwise,
            inference related hyperparameters are used (e.g.,
            `hparams['max_decoding_length_infer']`).
            If `None` (default), `TRAIN` mode is used.
            NOTE(review): the max-length switch uses
            `is_train_mode(mode)` inside `tf.cond`, which may consult the
            global mode rather than defaulting to TRAIN — confirm.
        **kwargs: Other keyword arguments for constructing helpers
            defined by `hparams["helper_train"]` or
            `hparams["helper_infer"]`.

    Returns:
        `(outputs, final_state, sequence_lengths)`, where

        - **`outputs`**: an object containing the decoder output on all \
        time steps.
        - **`final_state`**: is the cell state of the final time step.
        - **`sequence_lengths`**: is an int Tensor of shape `[batch_size]` \
        containing the length of each sample.

    Raises:
        ValueError: If :attr:`helper` is `None` and
            :attr:`decoding_strategy` is not one of "train_greedy",
            "infer_greedy" or "infer_sample".
    """
    # Helper: explicit `helper` > `decoding_strategy` > hparams config.
    if helper is not None:
        pass
    elif decoding_strategy is not None:
        if decoding_strategy == "train_greedy":
            helper = rnn_decoder_helpers._get_training_helper(
                inputs, sequence_length, embedding, input_time_major)
        elif decoding_strategy == "infer_greedy":
            helper = tx_helper.GreedyEmbeddingHelper(
                embedding, start_tokens, end_token)
        elif decoding_strategy == "infer_sample":
            helper = tx_helper.SampleEmbeddingHelper(
                embedding, start_tokens, end_token, softmax_temperature)
        else:
            raise ValueError(
                "Unknown decoding strategy: {}".format(decoding_strategy))
    else:
        if is_train_mode_py(mode):
            kwargs_ = copy.copy(self._hparams.helper_train.kwargs.todict())
            helper_type = self._hparams.helper_train.type
        else:
            kwargs_ = copy.copy(self._hparams.helper_infer.kwargs.todict())
            helper_type = self._hparams.helper_infer.type
        # Explicit arguments (even when None) override same-named hparams
        # kwargs; `**kwargs` then overrides both.
        kwargs_.update({
            "inputs": inputs,
            "sequence_length": sequence_length,
            "time_major": input_time_major,
            "embedding": embedding,
            "start_tokens": start_tokens,
            "end_token": end_token,
            "softmax_temperature": softmax_temperature
        })
        kwargs_.update(kwargs)
        helper = rnn_decoder_helpers.get_helper(helper_type, **kwargs_)
    self._helper = helper

    # Initial state
    # NOTE(review): evaluated after `self._helper` is set; `self.batch_size`
    # presumably derives from the helper, so keep this ordering.
    if initial_state is not None:
        self._initial_state = initial_state
    else:
        self._initial_state = self.zero_state(batch_size=self.batch_size,
                                              dtype=tf.float32)

    # Maximum decoding length: an explicit value wins; otherwise choose the
    # train/infer hparam at graph run time via `tf.cond`, with
    # `utils.MAX_SEQ_LENGTH` as the fallback for a None hparam.
    max_l = max_decoding_length
    if max_l is None:
        max_l_train = self._hparams.max_decoding_length_train
        if max_l_train is None:
            max_l_train = utils.MAX_SEQ_LENGTH
        max_l_infer = self._hparams.max_decoding_length_infer
        if max_l_infer is None:
            max_l_infer = utils.MAX_SEQ_LENGTH
        max_l = tf.cond(is_train_mode(mode),
                        lambda: max_l_train, lambda: max_l_infer)
    self.max_decoding_length = max_l
    # Decode
    outputs, final_state, sequence_lengths = dynamic_decode(
        decoder=self, impute_finished=impute_finished,
        maximum_iterations=max_l, output_time_major=output_time_major)

    # Variables only exist after dynamic_decode has run, so registration
    # happens here, once.
    if not self._built:
        self._add_internal_trainable_variables()
        # Add trainable variables of `self._cell` which may be
        # constructed externally.
        self._add_trainable_variable(
            layers.get_rnn_cell_trainable_variables(self._cell))
        if isinstance(self._output_layer, tf.layers.Layer):
            self._add_trainable_variable(
                self._output_layer.trainable_variables)
        # Add trainable variables of `self._beam_search_cell` which
        # may already be constructed and used.
        if self._beam_search_cell is not None:
            self._add_trainable_variable(
                self._beam_search_cell.trainable_variables)

        self._built = True

    return outputs, final_state, sequence_lengths