Example #1
    def _init_decoder(self):
        """
            decoder cell.
            attention적용 시 결과가 좋지 않음.
        """
        with tf.variable_scope("Decoder") as scope:

            def output_fn(outputs):
                return tf.contrib.layers.linear(outputs,
                                                self.vocab_size,
                                                scope=scope)

            decoder_fn_train = seq2seq.simple_decoder_fn_train(
                encoder_state=self.encoder_state)
            decoder_fn_inference = seq2seq.simple_decoder_fn_inference(
                output_fn=output_fn,
                encoder_state=self.encoder_state,
                embeddings=self.embedding_matrix,
                start_of_sequence_id=self.EOS,
                end_of_sequence_id=self.EOS,
                maximum_length=self.len_max,
                num_decoder_symbols=self.vocab_size,
            )

            (self.decoder_outputs_train, self.decoder_state_train,
             self.decoder_context_state_train) = (seq2seq.dynamic_rnn_decoder(
                 cell=self.decoder_cell,
                 decoder_fn=decoder_fn_train,
                 inputs=self.decoder_train_inputs_embedded,
                 sequence_length=[
                     self.len_max for _ in range(self.batch_size)
                 ],
                 time_major=True,
                 scope=scope,
             ))

            self.decoder_logits_train = output_fn(self.decoder_outputs_train)
            self.decoder_prediction_train = tf.argmax(
                self.decoder_logits_train,
                axis=-1,
                name='decoder_prediction_train')

            scope.reuse_variables()

            (self.decoder_logits_inference, self.decoder_state_inference,
             self.decoder_context_state_inference) = (
                 seq2seq.dynamic_rnn_decoder(
                     cell=self.decoder_cell,
                     decoder_fn=decoder_fn_inference,
                     time_major=True,
                     scope=scope,
                 ))
            self.decoder_prediction_inference = tf.argmax(
                self.decoder_logits_inference,
                axis=-1,
                name='decoder_prediction_inference')
Example #2
    def addDecoder(self):
        print('adding decoder...')
        cell = BasicRNNCell(2 * CONFIG.DIM_WordEmbedding)
        self.attention_states = self._encoder_outputs
        self.decoder_inputs_embedded = tf.nn.embedding_lookup(
            self.embedding, self.y_placeholder)
        # prepare attention:
        (attention_keys, attention_values, attention_score_fn,
         attention_construct_fn) = seq2seq.prepare_attention(
             attention_states=self.attention_states,
             attention_option='bahdanau',
             num_units=2 * CONFIG.DIM_WordEmbedding)

        if (self.is_training):
            # new Seq2seq train version
            self.check_op = tf.add_check_numerics_ops()
            decoder_fn_train = seq2seq.attention_decoder_fn_train(
                encoder_state=self._decoder_in_state,
                attention_keys=attention_keys,
                attention_values=attention_values,
                attention_score_fn=attention_score_fn,
                attention_construct_fn=attention_construct_fn,
                name='attention_decoder')
            (self.decoder_outputs_train, self.decoder_state_train,
             self.decoder_context_state_train) = seq2seq.dynamic_rnn_decoder(
                 cell=cell,
                 decoder_fn=decoder_fn_train,
                 inputs=self.decoder_inputs_embedded,
                 sequence_length=self.y_lens,
                 time_major=False)
            self.decoder_outputs = self.decoder_outputs_train

        else:

            # new Seq2seq version
            start_id = CONFIG.WORDS[CONFIG.STARTWORD]
            stop_id = CONFIG.WORDS[CONFIG.STOPWORD]
            decoder_fn_inference = seq2seq.attention_decoder_fn_inference(
                encoder_state=self._decoder_in_state,
                attention_keys=attention_keys,
                attention_values=attention_values,
                attention_score_fn=attention_score_fn,
                attention_construct_fn=attention_construct_fn,
                embeddings=self.embedding,
                start_of_sequence_id=start_id,
                end_of_sequence_id=stop_id,
                maximum_length=CONFIG.DIM_DECODER,
                num_decoder_symbols=CONFIG.DIM_VOCAB,
                output_fn=self.output_fn)
            (self.decoder_outputs_inference, self.decoder_state_inference,
             self.decoder_context_state_inference
             ) = seq2seq.dynamic_rnn_decoder(cell=cell,
                                             decoder_fn=decoder_fn_inference,
                                             time_major=False)
            self.decoder_outputs = self.decoder_outputs_inference
Example #3
    def decode_train(cell, embeddings, encoder_state, targets, targets_length, scope='decoder', reuse=None):
        """
        Args:
            cell: An RNNCell object
            embeddings: An embedding matrix with shape
                (vocab_size, word_dim)
            encoder_state: A tensor that contains the encoder state;
                its shape should match that of cell.zero_state
            targets: An int32 tensor with shape (batch, max_len), which
                contains word indices; it should start and end with
                the proper <BOS> and <EOS> symbols
            targets_length: An int32 tensor with shape (batch,), which
                contains the length of each sample in a batch
            scope: A VariableScope object or a string which indicates
                the scope
            reuse: A boolean value or None which specifies whether to
                reuse variables already defined in the scope

        Returns:
            decoder_outputs, which is a float32
            (batch, max_len, cell.output_size) tensor that contains
            the cell's hidden state per time step
        """

        with tf.variable_scope(scope, initializer=tf.orthogonal_initializer(), reuse=reuse):
            decoder_fn = seq2seq.simple_decoder_fn_train(encoder_state=encoder_state)
            targets_embed = tf.nn.embedding_lookup(params=embeddings, ids=targets)
            decoder_outputs, _, _ = seq2seq.dynamic_rnn_decoder(cell=cell, decoder_fn=decoder_fn, inputs=targets_embed,
                sequence_length=targets_length, time_major=False, scope='rnn')
        return decoder_outputs
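A minimal usage sketch for decode_train above (hypothetical placeholder tensors and sizes, assuming the TF 1.x contrib API and that decode_train is available as a plain module-level function):

import tensorflow as tf

vocab_size, word_dim, hidden_dim = 10000, 128, 256
cell = tf.contrib.rnn.GRUCell(hidden_dim)
embeddings = tf.get_variable('embeddings', [vocab_size, word_dim])
targets = tf.placeholder(tf.int32, [None, None], name='targets')          # (batch, max_len), <BOS> ... <EOS>
targets_length = tf.placeholder(tf.int32, [None], name='targets_length')  # (batch,)
encoder_state = tf.placeholder(tf.float32, [None, hidden_dim])            # must match cell.zero_state

# Per-step decoder hidden states with shape (batch, max_len, hidden_dim).
decoder_outputs = decode_train(cell, embeddings, encoder_state,
                               targets, targets_length, scope='decoder')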
def decode_training_set(encoder_state, decoder_cell, decoder_embedded_input,
                        sequence_length, decoding_scope, output_function,
                        keep_prob, batch_size):
    """Decode the training set."""
    attention_states = tf.zeros([batch_size, 1, decoder_cell.output_size])
    (attention_keys, attentions_values, attention_score_function,
     attention_construct_function) = prepare_attention(
         attention_states,
         attention_option='bahdanau',
         num_units=decoder_cell.output_size)
    training_decoder_function = attention_decoder_fn_train(
        encoder_state[0],
        attention_keys,
        attentions_values,
        attention_score_function,
        attention_construct_function,
        name='attn_dec_train')
    decoder_output, _, _ = dynamic_rnn_decoder(decoder_cell,
                                               training_decoder_function,
                                               decoder_embedded_input,
                                               sequence_length,
                                               scope=decoding_scope)
    decoder_output_dropout = tf.nn.dropout(decoder_output, keep_prob)

    return output_function(decoder_output_dropout)
def decode_test_set(encoder_state, decoder_cell, decoder_embeddings_matrix,
                    sos_id, eos_id, maximum_length, num_words, sequence_length,
                    decoding_scope, output_function, keep_prob, batch_size):
    """Decode the validation set."""
    attention_states = tf.zeros([batch_size, 1, decoder_cell.output_size])
    (attention_keys, attentions_values, attention_score_function,
     attention_construct_function) = prepare_attention(
         attention_states,
         attention_option='bahdanau',
         num_units=decoder_cell.output_size)
    test_decoder_function = attention_decoder_fn_inference(
        output_function,
        encoder_state[0],
        attention_keys,
        attentions_values,
        attention_score_function,
        attention_construct_function,
        decoder_embeddings_matrix,
        sos_id,
        eos_id,
        maximum_length,
        num_words,
        name='attn_dec_inf')
    test_predictions, _, _ = dynamic_rnn_decoder(decoder_cell,
                                                 test_decoder_function,
                                                 scope=decoding_scope)

    return test_predictions
Example #6
def decoder_test_set(encoder_state, decoder_cell, batch_size, decoder_scope,
                     keep_prob, decoder_embedding_matrix, sequence_length,
                     decoder_output_function, sos_id, eos_id, max_length,
                     num_symbols):
    attention_state = tf.zeros([batch_size, 1, decoder_cell.output_size])
    attention_key, attention_value, attention_score_function, attention_construct_function = seq2seq.prepare_attention(
        attention_states=attention_state,
        attention_option="nabdanau",
        num_units=decoder_cell.output_size)
    decoder_test_output = seq2seq.attention_decoder_fn_inference(
        output_fn=decoder_output_function,
        encoder_state=encoder_state,
        attention_keys=attention_key,
        attention_values=attention_value,
        attention_score_fn=attention_score_function,
        attention_construct_fn=attention_construct_function,
        embeddings=decoder_embedding_matrix,
        start_of_sequence_id=sos_id,
        end_of_sequence_id=eos_id,
        maximum_length=max_length,
        num_decoder_symbols=num_symbols,
        dtype=tf.float32,
        name="attn_dec_inf")
    decoder_output, _, _ = seq2seq.dynamic_rnn_decoder(decoder_cell,
                                                       decoder_test_output,
                                                       scope=decoder_scope)

    return decoder_output
Example #7
def decode_training_set(encoder_state, decoder_cell, decoder_embedded_input,
                        sequence_length, decoding_scope, output_function,
                        keep_prob, batch_size):
    attention_states = tf.zeros([batch_size, 1, decoder_cell.output_size])

    attention_keys, attention_values, attention_score_function, attention_construct_function \
        = seq2seq.prepare_attention(attention_states,
                                    attention_option="bahdanau",
                                    num_units=decoder_cell.output_size)

    training_decoder_function = seq2seq.attention_decoder_fn_train(
        encoder_state[0],
        attention_keys,
        attention_values,
        attention_score_function,
        attention_construct_function,
        name="attn_dec_train")

    decoder_output, _, _ = seq2seq.dynamic_rnn_decoder(
        decoder_cell,
        training_decoder_function,
        decoder_embedded_input,
        sequence_length,
        scope=decoding_scope)

    decoder_output_dropout = tf.nn.dropout(decoder_output, keep_prob)
    return output_function(decoder_output_dropout)
def decode_validation_set(encoder_state, decoder_cell,
                          decoder_embeddings_matrix, sos_id, eos_id,
                          max_length, num_words, decoding_scope,
                          output_function, keep_prob, batch_size):
    attention_states = tf.zeros([batch_size, 1, decoder_cell.output_size])
    attention_keys, attention_values, attention_score_fx, attention_construct_fx = prepare_attention(
        attention_states,
        attention_option='bahdanau',
        num_units=decoder_cell.output_size)
    validate_decoder_fx = attention_decoder_fn_inference(
        output_fn=output_function,
        encoder_state=encoder_state[0],
        attention_keys=attention_keys,
        attention_values=attention_values,
        attention_score_fn=attention_score_fx,
        attention_construct_fn=attention_construct_fx,
        embeddings=decoder_embeddings_matrix,
        start_of_sequence_id=sos_id,
        end_of_sequence_id=eos_id,
        maximum_length=max_length,
        num_decoder_symbols=num_words,
        name='attn_dec_inf')
    predictions, _, _ = dynamic_rnn_decoder(cell=decoder_cell,
                                            decoder_fn=validate_decoder_fx,
                                            scope=decoding_scope)
    return predictions
Example #9
def decode_test_set(encoder_state, decoder_cell, decoder_embeddings_matrix,
                    sos_id, eos_id, maximum_length, num_words, decoding_scope,
                    output_function, keep_prob, batch_size):
    attention_states = tf.zeros([batch_size, 1, decoder_cell.output_size])

    attention_keys, attention_values, attention_score_function, attention_construct_function \
        = seq2seq.prepare_attention(attention_states,
                                    attention_option="bahdanau",
                                    num_units=decoder_cell.output_size)

    test_decoder_function = seq2seq.attention_decoder_fn_inference(
        output_function,
        encoder_state[0],
        attention_keys,
        attention_values,
        attention_score_function,
        attention_construct_function,
        decoder_embeddings_matrix,
        sos_id,
        eos_id,
        maximum_length,
        num_words,
        name="attn_dec_inf")

    test_predictions, _, _ = seq2seq.dynamic_rnn_decoder(decoder_cell,
                                                         test_decoder_function,
                                                         scope=decoding_scope)
    return test_predictions
    def decode_topk(self):
        with self.graph.as_default():
            with tf.variable_scope("Decoder") as scope:
                tf.get_variable_scope().reuse_variables()

                def output_fn(outputs):
                    return tf.contrib.layers.linear(outputs,
                                                    self.vocab_size,
                                                    scope=scope)

                decoder_fn_inference = seq2seq.attention_decoder_fn_inference(
                    output_fn=output_fn,
                    encoder_state=self.encoder_final_state,
                    attention_keys=self.attention_keys,
                    attention_values=self.attention_values,
                    attention_score_fn=self.attention_score_fn,
                    attention_construct_fn=self.attention_construct_fn,
                    embeddings=self.embed,
                    start_of_sequence_id=self.EOS,
                    end_of_sequence_id=self.EOS,
                    maximum_length=23,  # max_twee_len + 3,
                    num_decoder_symbols=self.vocab_size)
                (self.decoder_logits_inference_beam,
                 self.decoder_state_inference_beam,
                 self.decoder_context_state_inference_beam) = (
                     seq2seq.dynamic_rnn_decoder(
                         cell=self.decoder_cell,
                         decoder_fn=decoder_fn_inference,
                         time_major=True,
                         scope=scope))

        return self.decoder_logits_inference_beam, self.decoder_state_inference_beam
    def TweetInitDecoder(self, input_state):
        with self.graph.as_default():
            with tf.variable_scope("TweetInitDecoder") as scope:

                def output_fn(outputs):
                    return tf.contrib.layers.linear(outputs,
                                                    self.vocab_size,
                                                    scope=scope)

                decoder_fn_inference = seq2seq.attention_decoder_fn_inference(
                    output_fn=output_fn,
                    encoder_state=input_state,  # self.encoder_final_state
                    attention_keys=self.attention_keys,
                    attention_values=self.attention_values,
                    attention_score_fn=self.attention_score_fn,
                    attention_construct_fn=self.attention_construct_fn,
                    embeddings=self.embed,
                    start_of_sequence_id=self.EOS,
                    end_of_sequence_id=self.EOS,
                    maximum_length=23,  # max_twee_len + 3,
                    num_decoder_symbols=self.vocab_size)
                (self.tidecoder_logits_inference,
                 self.tidecoder_state_inference,
                 self.tidecoder_context_state_inference) = (
                     seq2seq.dynamic_rnn_decoder(
                         cell=self.decoder_cell,
                         decoder_fn=decoder_fn_inference,
                         time_major=True,
                         scope=scope))

                self.tidecoder_prediction_inference = tf.argmax(
                    self.tidecoder_logits_inference,
                    axis=-1,
                    name='TIdecoder_prediction_inference')
        return self.tidecoder_prediction_inference
Example #12
def recurrent_layer(tensor, cell=None, hidden_dims=128, sequence_length=None, decoder_fn=None, 
                    activation=tf.nn.tanh, initializer=tf.orthogonal_initializer(), initial_state=None, 
                    keep_prob=1.0,
                    return_final_state=False, return_next_cell_input=True, **opts):
    if cell is None:
        cell = tf.contrib.rnn.BasicRNNCell(hidden_dims, activation=activation)
        # cell = tf.contrib.rnn.LSTMCell(hidden_dims, activation=activation)

    if keep_prob < 1.0:
        keep_prob = _global_keep_prob(keep_prob)
        cell = tf.contrib.rnn.DropoutWrapper(cell, keep_prob, keep_prob)

    if opts.get("name"):
        tf.add_to_collection(opts.get("name"), cell)

    if decoder_fn is None:
        outputs, final_state = tf.nn.dynamic_rnn(cell, tensor, 
            sequence_length=sequence_length, initial_state=initial_state, dtype=tf.float32)
        final_context_state = None
    else:
        # TODO: turn off sequence_length?
        outputs, final_state, final_context_state = seq2seq.dynamic_rnn_decoder(
            cell, decoder_fn, inputs=None, sequence_length=sequence_length)

    if return_final_state:
        return final_state
    else:
        return outputs
 def decoder_train_layer(self, encoded_state, token_embedding):
     with tf.variable_scope('decoder'):
         dynamic_fn_train = seq2seq.simple_decoder_fn_train(encoded_state)
         decoder_outputs, state, context = seq2seq.dynamic_rnn_decoder(
             self.decoder_cell, dynamic_fn_train, token_embedding,
             self.sequence_lengths)
     return decoder_outputs
 def decoder(self,encoder_state,inputs=None,is_train=True):
     '''
     Decoder
     '''
     with tf.variable_scope("decoder") as scope:
         if is_train is True:
             decoder_fn=seq2seq.simple_decoder_fn_train(encoder_state)
             outputs, final_state, final_context_state = seq2seq.dynamic_rnn_decoder(
                 self.decoder_cell, decoder_fn=decoder_fn, inputs=inputs,
                 sequence_length=self.seq_len, time_major=False, scope=scope)
         else:
             tf.get_variable_scope().reuse_variables()
             # At decoding time, compute each word's probability from the decoder
             # embedding and the decoder bias.
             output_fn = lambda x: tf.nn.softmax(
                 tf.matmul(x, self.dec_embedding, transpose_b=True) + self.dec_bias)
             decoder_fn=seq2seq.simple_decoder_fn_inference(output_fn=output_fn,encoder_state=encoder_state,embeddings=self.embedding,
                                                            start_of_sequence_id=0,end_of_sequence_id=0,maximum_length=self.subject_len,
                                                            num_decoder_symbols=self.vocab_size,dtype=tf.int32)
             outputs, final_state, final_context_state = seq2seq.dynamic_rnn_decoder(
                 self.decoder_cell, decoder_fn=decoder_fn, inputs=None,
                 sequence_length=self.seq_len, time_major=False, scope=scope)
         
     return outputs,final_state,final_context_state
Example #15
 def decoder(self):
     cell = self._rnn_cell()
     initial_state = cell.zero_state(self.model.batch_size,
                                     dtype=tf.float32)
     decoder_fn = seq2seq.simple_decoder_fn_train(initial_state)
     output = seq2seq.dynamic_rnn_decoder(cell,
                                          decoder_fn,
                                          self.convolution,
                                          self.sequence_lengths,
                                          time_major=True)
     return output[0]
 def decoder_inference_layer(self, encoded_state, token_embedding):
     num_tokens = self.data.token_vocab_size
     with tf.variable_scope('decoder') as scope:
         scope.reuse_variables()
         dynamic_fn_test = seq2seq.simple_decoder_fn_inference(
             None, encoded_state, self.token_embedding_matrix,
             self.data.GO_CODE, self.data.STOP_CODE,
             self.data.token_sequence_length - 2, num_tokens)
         decoder_outputs, state, context = seq2seq.dynamic_rnn_decoder(
             self.decoder_cell, dynamic_fn_test)
     return decoder_outputs
Example #17
  def encoder_decoder(self):

    with tf.variable_scope("seq2seq") as scope:
      # Encoder
      enc_cell = tf.contrib.rnn.GRUCell(self.n_cells)
      enc_cell = tf.contrib.rnn.DropoutWrapper(cell=enc_cell,
        output_keep_prob=self.dropout_rate)  # Important: RNN version of Dropout!
      enc_init = tf.get_variable('init_state', [self.batch_size, self.n_cells],
         initializer=self.initializer())

      if not toy:
        questions_batch_embedding = tf.nn.embedding_lookup(self.embedding_matrix, self.question_ids)

      with tf.variable_scope("encoder"):
        if toy:
          _, enc_state = tf.nn.dynamic_rnn(enc_cell,
              self.input_placeholder, sequence_length=self.enc_seq_len,
              initial_state=enc_init, dtype=tf.float32)
        else:
          _, enc_state = tf.nn.dynamic_rnn(enc_cell,
              questions_batch_embedding, sequence_length=self.enc_seq_len,
              initial_state=enc_init, dtype=tf.float32)

      # Decoder
      if toy:
        dec_stage = "inference" if self.labels is None else "training"
      else:
        dec_stage = "inference" if self.answer_ids is None else "training" # Weird line...
      dec_cell = tf.contrib.rnn.GRUCell(self.n_cells)
      dec_function = self.decoder_components(dec_stage, "function", enc_state)
      dec_inputs = self.decoder_components(dec_stage, "inputs", None)
      dec_seq_len = self.decoder_components(dec_stage, "sequence_length", None)
      with tf.variable_scope("decoder"):
        pred, _, _ = seq2seq.dynamic_rnn_decoder(dec_cell, dec_function,
          inputs=dec_inputs, sequence_length=dec_seq_len)
      with tf.variable_scope("decoder", reuse=True):
        pred, dec_state, _ = seq2seq.dynamic_rnn_decoder(dec_cell, dec_function,
          inputs=dec_inputs, sequence_length=dec_seq_len)

    return pred, dec_stage # not state
def decode_inference(cell,
                     embeddings,
                     encoder_state,
                     output_fn,
                     vocab_size,
                     bos_id,
                     eos_id,
                     max_length,
                     scope='decoder',
                     reuse=None):
    """
    Args:
        cell: An RNNCell object
        embeddings: An embedding matrix with shape
            (vocab_size, word_dim)
        encoder_state: A tensor that contains the encoder state;
            its shape should match that of cell.zero_state
        output_fn: A function that projects a vector with length
            cell.output_size into a vector with length vocab_size;
            please beware of the scope, since it will be called inside
            'scope/rnn' scope
        vocab_size: The size of a vocabulary set
        bos_id: The ID of the beginning-of-sentence symbol
        eos_id: The ID of the end-of-sentence symbol
        max_length: The maximum length of a generated sentence;
            generation stops once this many words have been produced,
            even if <EOS> has not appeared by then
        scope: A VariableScope object or a string which indicates
            the scope
        reuse: A boolean value or None which specifies whether to
            reuse variables already defined in the scope

    Returns:
        generated, which is an int64 (batch, <=max_len)
        tensor that contains the IDs of the generated words
    """

    with tf.variable_scope(scope,
                           initializer=tf.orthogonal_initializer(),
                           reuse=reuse):
        decoder_fn = seq2seq.simple_decoder_fn_inference(
            output_fn=output_fn,
            encoder_state=encoder_state,
            embeddings=embeddings,
            start_of_sequence_id=bos_id,
            end_of_sequence_id=eos_id,
            maximum_length=max_length,
            num_decoder_symbols=vocab_size)
        generated_logits, _, _ = seq2seq.dynamic_rnn_decoder(
            cell=cell, decoder_fn=decoder_fn, time_major=False, scope='rnn')
    generated = tf.argmax(generated_logits, axis=2)
    return generated
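A hypothetical sketch of pairing decode_train (Example #3) with decode_inference so that both passes share the decoder variables through the scope's reuse flag; the output projection and all tensor names here are illustrative assumptions, not part of the original code:

vocab_size = 10000
output_fn = lambda x: tf.contrib.layers.linear(x, vocab_size, scope='output_proj')

# The training pass creates the decoder variables under the 'decoder' scope.
train_outputs = decode_train(cell, embeddings, encoder_state,
                             targets, targets_length, scope='decoder')
train_logits = output_fn(train_outputs)

# The inference pass reuses those variables; note that output_fn is called
# inside the 'decoder/rnn' scope here, as the docstring above warns.
generated_ids = decode_inference(cell, embeddings, encoder_state, output_fn,
                                 vocab_size, bos_id=1, eos_id=2,
                                 max_length=30, scope='decoder', reuse=True)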
 def decoder(self,encoder_state,attention_states,inputs=None,is_train=True):
     '''
     Attention-based decoder
     1. Call seq2seq.prepare_attention to build the attention keys/values/functions
     2. At training time, define the attention_decoder_fn_train used by dynamic_rnn_decoder
     3. At inference time, define the attention_decoder_fn_inference used by dynamic_rnn_decoder
     4. Call seq2seq.dynamic_rnn_decoder with the pieces produced by the steps above
     '''
     with tf.variable_scope("decoder") as scope:
         #1. prepare attention
         keys, values, score_fn, construct_fn = seq2seq.prepare_attention(
             attention_states=attention_states, attention_option="luong",
             num_units=self.emb_dim)
         if is_train is True:
             decoder_fn = seq2seq.attention_decoder_fn_train(
                 encoder_state, attention_keys=keys, attention_values=values,
                 attention_score_fn=score_fn, attention_construct_fn=construct_fn)
             outputs, final_state, final_context_state = seq2seq.dynamic_rnn_decoder(
                 self.decoder_cell, decoder_fn=decoder_fn, inputs=inputs,
                 sequence_length=self.seq_len, time_major=False, scope=scope)
         else:
             tf.get_variable_scope().reuse_variables()
             # At decoding time, compute each word's probability from the decoder
             # embedding and the decoder bias.
             output_fn = lambda x: tf.nn.softmax(
                 tf.matmul(x, self.dec_embedding, transpose_b=True) + self.dec_bias)
             decoder_fn=seq2seq.attention_decoder_fn_inference(output_fn=output_fn,encoder_state=encoder_state,attention_keys=keys,attention_values=values,
                                                               attention_score_fn=score_fn,attention_construct_fn=construct_fn,embeddings=self.dec_embedding,
                                                               start_of_sequence_id=0,end_of_sequence_id=1,maximum_length=5,num_decoder_symbols=self.vocab_size)
             outputs, final_state, final_context_state = seq2seq.dynamic_rnn_decoder(
                 self.decoder_cell, decoder_fn=decoder_fn, inputs=None,
                 sequence_length=self.seq_len, time_major=False, scope=scope)
         
     return outputs,final_state,final_context_state
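A hypothetical call sequence for the attention decoder method above, following the four steps in its docstring (placeholder names; assumes the surrounding model class supplies decoder_cell, seq_len, dec_embedding, dec_bias and emb_dim):

# Training mode: teacher forcing with the embedded target inputs.
train_outputs, _, _ = model.decoder(encoder_state, attention_states,
                                    inputs=decoder_inputs_embedded, is_train=True)
# Inference mode: the decoder feeds its own predictions back in.
infer_outputs, _, _ = model.decoder(encoder_state, attention_states, is_train=False)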
Example #20
def recurrent_layer(tensor,
                    cell=None,
                    hidden_dims=128,
                    sequence_length=None,
                    decoder_fn=None,
                    activation=tf.nn.tanh,
                    initializer=tf.orthogonal_initializer(),
                    initial_state=None,
                    keep_prob=1.0,
                    return_final_state=False,
                    return_next_cell_input=True,
                    **opts):
    if cell is None:
        cell = tf.contrib.rnn.BasicRNNCell(hidden_dims, activation=activation)
        # cell = tf.contrib.rnn.LSTMCell(hidden_dims, activation=activation)

    if keep_prob < 1.0:
        keep_prob = _global_keep_prob(keep_prob)
        cell = tf.contrib.rnn.DropoutWrapper(cell, keep_prob, keep_prob)

    if opts.get("name"):
        tf.add_to_collection(opts.get("name"), cell)

    if decoder_fn is None:
        outputs, final_state = tf.nn.dynamic_rnn(
            cell,
            tensor,
            sequence_length=sequence_length,
            initial_state=initial_state,
            dtype=tf.float32)
        final_context_state = None
    else:
        # TODO: turn off sequence_length?
        outputs, final_state, final_context_state = seq2seq.dynamic_rnn_decoder(
            cell, decoder_fn, inputs=None, sequence_length=sequence_length)

    if return_final_state:
        return final_state
    else:
        return outputs
Example #21
    def __init__(self, para):
        para.fac = int(para.bidirectional) + 1
        self._para = para
        if para.rnn_type == 0:  #basic rnn

            def unit_cell(fac):
                return tf.contrib.rnn.BasicRNNCell(para.hidden_size * fac)
        elif para.rnn_type == 1:  #basic LSTM

            def unit_cell(fac):
                return tf.contrib.rnn.BasicLSTMCell(para.hidden_size * fac)
        elif para.rnn_type == 2:  #full LSTM

            def unit_cell(fac):
                return tf.contrib.rnn.LSTMCell(para.hidden_size * fac,
                                               use_peepholes=True)
        elif para.rnn_type == 3:  #GRU

            def unit_cell(fac):
                return tf.contrib.rnn.GRUCell(para.hidden_size * fac)

        rnn_cell = unit_cell
        #dropout layer
        if not self.is_test() and para.keep_prob < 1:

            def rnn_cell(fac):
                return tf.contrib.rnn.DropoutWrapper(
                    unit_cell(fac), output_keep_prob=para.keep_prob)

        #multi-layer rnn
        encoder_cell =\
          tf.contrib.rnn.MultiRNNCell([rnn_cell(1) for _ in range(para.layer_num)])
        if para.bidirectional:
            b_encoder_cell = tf.contrib.rnn.MultiRNNCell(
                [rnn_cell(1) for _ in range(para.layer_num)])
        #feed in data in batches
        if not self.is_test():
            video, caption, v_len, c_len = self.get_single_example(para)
            videos, captions, v_lens, c_lens =\
                tf.train.batch([video, caption, v_len, c_len],
                               batch_size=para.batch_size, dynamic_pad=True)
            #sparse tensor cannot be sliced
            targets = tf.sparse_tensor_to_dense(captions)
            decoder_in = targets[:, :-1]
            decoder_out = targets[:, 1:]
            c_lens = tf.to_int32(c_lens)
        else:
            video, v_len = self.get_single_example(para)
            videos, v_lens =\
                tf.train.batch([video, v_len],
                               batch_size=para.batch_size, dynamic_pad=True)
        v_lens = tf.to_int32(v_lens)
        with tf.variable_scope('embedding'):
            if para.use_pretrained:
                W_E =\
                  tf.Variable(tf.constant(0., shape= [para.vocab_size, para.w_emb_dim]),
                              trainable=False, name='W_E')
                self._embedding = tf.placeholder(
                    tf.float32, [para.vocab_size, para.w_emb_dim])
                self._embed_init = W_E.assign(self._embedding)
            else:
                W_E = tf.get_variable('W_E', [para.vocab_size, para.w_emb_dim],
                                      dtype=tf.float32)

        if not self.is_test():
            decoder_in_embed = tf.nn.embedding_lookup(W_E, decoder_in)

        if para.v_emb_dim < para.video_dim:
            inputs = fully_connected(videos, para.v_emb_dim)
        else:
            inputs = videos

        if not self.is_test() and para.keep_prob < 1:
            inputs = tf.nn.dropout(inputs, para.keep_prob)

        if not para.bidirectional:
            encoder_outputs, encoder_states =\
              tf.nn.dynamic_rnn(encoder_cell, inputs,
                                sequence_length=v_lens, dtype=tf.float32)
        else:
            encoder_outputs, encoder_states =\
              tf.nn.bidirectional_dynamic_rnn(encoder_cell, b_encoder_cell,
                                              inputs, sequence_length=v_lens,
                                              dtype=tf.float32)
            encoder_states = tuple([
                LSTMStateTuple(tf.concat([f_st.c, b_st.c], 1),
                               tf.concat([f_st.h, b_st.h], 1))
                for f_st, b_st in zip(encoder_states[0], encoder_states[1])
            ])
            encoder_outputs = tf.concat(
                [encoder_outputs[0], encoder_outputs[1]], 2)

        with tf.variable_scope('softmax'):
            softmax_w = tf.get_variable(
                'w', [para.hidden_size * para.fac, para.vocab_size],
                dtype=tf.float32)
            softmax_b = tf.get_variable('b', [para.vocab_size],
                                        dtype=tf.float32)
            output_fn = lambda output: tf.nn.xw_plus_b(output, softmax_w,
                                                       softmax_b)

        decoder_cell =\
          tf.contrib.rnn.MultiRNNCell([rnn_cell(para.fac)
                                       for _ in range(para.layer_num)])
        if para.attention > 0:
            at_option = ["bahdanau", "luong"][para.attention - 1]
            at_keys, at_vals, at_score, at_cons =\
              seq2seq.prepare_attention(attention_states=encoder_outputs,
                                        attention_option=at_option,
                                        num_units=para.hidden_size*para.fac)
        if self.is_test():
            if para.attention:
                decoder_fn_inference = seq2seq.attention_decoder_fn_inference(
                    output_fn=output_fn,
                    encoder_state=encoder_states,
                    attention_keys=at_keys,
                    attention_values=at_vals,
                    attention_score_fn=at_score,
                    attention_construct_fn=at_cons,
                    embeddings=W_E,
                    start_of_sequence_id=2,
                    end_of_sequence_id=3,
                    maximum_length=20,
                    num_decoder_symbols=para.vocab_size)
            else:
                decoder_fn_inference = seq2seq.simple_decoder_fn_inference(
                    output_fn=output_fn,
                    encoder_state=encoder_states,
                    embeddings=W_E,
                    start_of_sequence_id=2,
                    end_of_sequence_id=3,
                    maximum_length=20,
                    num_decoder_symbols=para.vocab_size)
            with tf.variable_scope('decode', reuse=True):
                decoder_logits, _, _ =\
                  seq2seq.dynamic_rnn_decoder(cell=decoder_cell,
                                              decoder_fn=decoder_fn_inference)
            self._prob = tf.nn.softmax(decoder_logits)

        else:
            global_step = tf.contrib.framework.get_or_create_global_step()

            def decoder_fn_train(time, cell_state, cell_input, cell_output,
                                 context):
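                # Custom decoder_fn for training: with scheduled sampling enabled,
                # the ground-truth input is replaced by the embedding of the model's
                # previous prediction with probability 1 - epsilon, where epsilon
                # decays as global_step grows over the training epochs.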
                if para.scheduled_sampling and cell_output is not None:
                    epsilon = tf.cast(
                        1 - (global_step //
                             (para.tot_train_num // para.batch_size + 1) /
                             para.max_epoch), tf.float32)
                    cell_input = tf.cond(
                        tf.less(tf.random_uniform([1]), epsilon)[0],
                        lambda: cell_input, lambda: tf.gather(
                            W_E, tf.argmax(output_fn(cell_output), 1)))
                if cell_state is None:
                    cell_state = encoder_states
                    if para.attention:
                        attention = _init_attention(encoder_states)
                else:
                    if para.attention:
                        cell_output = attention = at_cons(
                            cell_output, at_keys, at_vals)
                if para.attention:
                    nxt_cell_input = tf.concat([cell_input, attention], 1)
                else:
                    nxt_cell_input = cell_input
                return None, cell_state, nxt_cell_input, cell_output, context

            with tf.variable_scope('decode', reuse=None):
                (decoder_outputs, _, _) =\
                  seq2seq.dynamic_rnn_decoder(cell=decoder_cell,
                                              decoder_fn=decoder_fn_train,
                                              inputs=decoder_in_embed,
                                              sequence_length=c_lens)
            decoder_outputs =\
              tf.reshape(decoder_outputs, [-1, para.hidden_size*para.fac])
            c_len_max = tf.reduce_max(c_lens)

            logits = output_fn(decoder_outputs)
            logits = tf.reshape(logits,
                                [para.batch_size, c_len_max, para.vocab_size])
            self._prob = tf.nn.softmax(logits)

            msk = tf.sequence_mask(c_lens, dtype=tf.float32)
            loss = sequence_loss(logits, decoder_out, msk)

            self._cost = cost = tf.reduce_mean(loss)

            #if validation or testing, exit here
            if self.is_valid(): return

            #clip global gradient norm
            tvars = tf.trainable_variables()
            grads, _ = tf.clip_by_global_norm(tf.gradients(cost, tvars),
                                              para.max_grad_norm)
            optimizer = optimizers[para.optimizer](para.learning_rate)
            self._eval = optimizer.apply_gradients(zip(grads, tvars),
                                                   global_step=global_step)
    def _init_decoder(self):
        with tf.variable_scope("Decoder") as scope:

            def output_fn(outputs):
                return tf.contrib.layers.linear(outputs,
                                                self.vocab_size,
                                                scope=scope)

            if not self.attention:
                decoder_fn_train = seq2seq.simple_decoder_fn_train(
                    encoder_state=self.encoder_state)
                decoder_fn_inference = seq2seq.simple_decoder_fn_inference(
                    output_fn=output_fn,
                    encoder_state=self.encoder_state,
                    embeddings=self.embedding_matrix,
                    start_of_sequence_id=data_utils.GO_ID,
                    end_of_sequence_id=data_utils.EOS_ID,
                    maximum_length=FLAGS.max_inf_target_len,
                    num_decoder_symbols=self.vocab_size,
                )
            else:

                # attention_states: size [batch_size, max_time, num_units]
                attention_states = tf.transpose(self.encoder_outputs,
                                                [1, 0, 2])

                #attention_states = tf.zeros([batch_size, 1, self.decoder_hidden_units])

                (attention_keys, attention_values, attention_score_fn,
                 attention_construct_fn) = seq2seq.prepare_attention(
                     attention_states=attention_states,
                     attention_option="bahdanau",
                     num_units=self.decoder_hidden_units,
                 )

                decoder_fn_train = seq2seq.attention_decoder_fn_train(
                    encoder_state=self.encoder_state,
                    attention_keys=attention_keys,
                    attention_values=attention_values,
                    attention_score_fn=attention_score_fn,
                    attention_construct_fn=attention_construct_fn,
                    name='attention_decoder')

                decoder_fn_inference = seq2seq.attention_decoder_fn_inference(
                    output_fn=output_fn,
                    encoder_state=self.encoder_state,
                    attention_keys=attention_keys,
                    attention_values=attention_values,
                    attention_score_fn=attention_score_fn,
                    attention_construct_fn=attention_construct_fn,
                    embeddings=self.embedding_matrix,
                    start_of_sequence_id=data_utils.GO_ID,
                    end_of_sequence_id=data_utils.EOS_ID,
                    maximum_length=FLAGS.max_inf_target_len,
                    num_decoder_symbols=self.vocab_size,
                )

            (self.decoder_outputs_train, self.decoder_state_train,
             self.decoder_context_state_train) = (seq2seq.dynamic_rnn_decoder(
                 cell=self.decoder_cell,
                 decoder_fn=decoder_fn_train,
                 inputs=self.decoder_train_inputs_embedded,
                 sequence_length=self.decoder_train_length,
                 time_major=True,
                 scope=scope,
             ))

            self.decoder_outputs_train = tf.nn.dropout(
                self.decoder_outputs_train, _keep_prob)

            self.decoder_logits_train = output_fn(self.decoder_outputs_train)

            # reusing the scope of training to use the same variables for inference
            scope.reuse_variables()

            (self.decoder_logits_inference, self.decoder_state_inference,
             self.decoder_context_state_inference) = (
                 seq2seq.dynamic_rnn_decoder(
                     cell=self.decoder_cell,
                     decoder_fn=decoder_fn_inference,
                     time_major=True,
                     scope=scope,
                 ))

            self.decoder_prediction_inference = tf.argmax(
                self.decoder_logits_inference,
                axis=-1,
                name='decoder_prediction_inference')
Example #23
    return _decoder_fn


if __name__ == "__main__":
    import numpy as np
    from tensorflow.contrib import seq2seq
    from tensorflow.contrib import rnn

    sequence_length = 25
    vocab_size = 100
    hidden_dims = 10
    batch_size = 32

    cell = rnn.BasicRNNCell(hidden_dims)
    encoder_state = tf.constant(np.random.randn(batch_size, hidden_dims), dtype=np.float32)
    embeddings = tf.constant(np.random.randn(vocab_size, hidden_dims), dtype=np.float32)
    output_W = tf.transpose(embeddings)  # -- tied embeddings
    output_b = tf.constant(np.random.randn(vocab_size), dtype=np.float32)
    output_projections = (output_W, output_b)
    maximum_length=tf.reduce_max(sequence_length) + 3

    decoder_fn = gumbel_decoder_fn(encoder_state, embeddings, output_projections, maximum_length)

    outputs, final_state, final_context_state = seq2seq.dynamic_rnn_decoder(
            cell, decoder_fn, inputs=None, sequence_length=sequence_length)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        outputs_ = sess.run(outputs)        
        print("outputs shape", outputs_.shape)
Example #24
    def _init_decoder(self, output_projection):
        '''
        Decoder phase
        '''
        with tf.variable_scope("Decoder") as scope:
            self.decoder_inputs_embedded = tf.nn.embedding_lookup(self.lookup_matrix, self.decoder_inputs)

            (attention_keys,
             attention_values,
             attention_score_fn,
             attention_construct_fn) = seq2seq.prepare_attention(
                attention_states=self.attention_states,
                attention_option="bahdanau",
                num_units=self.args.h_units_decoder,
            )

            # attention is added
            decoder_fn_train = seq2seq.attention_decoder_fn_train(
                encoder_state=self.encoder_state,
                attention_keys=attention_keys,
                attention_values=attention_values,
                attention_score_fn=attention_score_fn,
                attention_construct_fn=attention_construct_fn,
                name='attention_decoder'
            )

            decoder_fn_inference = seq2seq.attention_decoder_fn_inference(
                output_fn=output_projection,
                encoder_state=self.encoder_state,
                attention_keys=attention_keys,
                attention_values=attention_values,
                attention_score_fn=attention_score_fn,
                attention_construct_fn=attention_construct_fn,
                embeddings=self.lookup_matrix,
                start_of_sequence_id=self.textData.goToken,
                end_of_sequence_id=self.textData.eosToken,
                maximum_length=tf.reduce_max(self.decoder_targets_length),
                num_decoder_symbols=self.textData.getVocabularySize(),
            )

            # Check back here later... does the hidden size of decoder_cell have to
            # match the size of the embedding layer?
            # !!!
            # decoder_outputs_train.shape = (batch_size, n_words, hidden_size)
            (
            self.decoder_outputs_train, decoder_state_train, decoder_context_state_train) = seq2seq.dynamic_rnn_decoder(
                cell=self.decoder_cell,
                decoder_fn=decoder_fn_train,
                inputs=self.decoder_inputs_embedded,
                sequence_length=self.decoder_targets_length,
                time_major=False,
                scope=scope
            )

            # self.decoder_logits_train = output_projection(self.decoder_outputs_train)
            # self.decoder_logits_train_trans = tf.reshape(self.decoder_outputs_train, [1,0,2])

            self.decoder_logits_train = tf.transpose(tf.map_fn(output_projection,
                                                               tf.transpose(self.decoder_outputs_train, [1, 0, 2])),
                                                     [1, 0, 2])

            self.decoder_prediction_train = tf.argmax(self.decoder_logits_train, axis=-1,
                                                      name='decoder_prediction_train')

            # for both training and inference
            scope.reuse_variables()

            (self.decoder_logits_inference,
             decoder_state_inference,
             decoder_context_state_inference) = (
                seq2seq.dynamic_rnn_decoder(
                    cell=self.decoder_cell,
                    decoder_fn=decoder_fn_inference,
                    time_major=False,
                    scope=scope
                )
            )
    def decoder_adv(self, max_twee_len):
        with self.graph.as_default():
            with tf.variable_scope("Decoder") as scope:
                self.decoder_length = max_twee_len + 3

                def output_fn(outputs):
                    return tf.contrib.layers.linear(outputs,
                                                    self.vocab_size,
                                                    scope=scope)

                # self.decoder_cell = LSTMCell(self.decoder_hidden_nodes)
                self.decoder_cell = GRUCell(self.decoder_hidden_nodes)
                if not self.attention:
                    decoder_train = seq2seq.simple_decoder_fn_train(
                        encoder_state=self.encoder_final_state)
                    decoder_inference = seq2seq.simple_decoder_fn_inference(
                        output_fn=output_fn,
                        encoder_state=self.encoder_final_state,
                        embeddings=self.embed,
                        start_of_sequence_id=self.EOS,
                        end_of_sequence_id=self.EOS,
                        maximum_length=self.decoder_length,
                        num_decoder_symbols=self.vocab_size)
                else:
                    # attention_states: size [batch_size, max_time, num_units]
                    self.attention_states = tf.transpose(
                        self.encoder_output, [1, 0, 2])
                    (self.attention_keys, self.attention_values, self.attention_score_fn, self.attention_construct_fn) = \
                        seq2seq.prepare_attention(attention_states = self.attention_states, attention_option = "bahdanau",
                                                  num_units = self.decoder_hidden_nodes)

                    decoder_fn_train = seq2seq.attention_decoder_fn_train(
                        encoder_state=self.encoder_final_state,
                        attention_keys=self.attention_keys,
                        attention_values=self.attention_values,
                        attention_score_fn=self.attention_score_fn,
                        attention_construct_fn=self.attention_construct_fn,
                        name="attention_decoder")

                    decoder_fn_inference = seq2seq.attention_decoder_fn_inference(
                        output_fn=output_fn,
                        encoder_state=self.encoder_final_state,
                        attention_keys=self.attention_keys,
                        attention_values=self.attention_values,
                        attention_score_fn=self.attention_score_fn,
                        attention_construct_fn=self.attention_construct_fn,
                        embeddings=self.embed,
                        start_of_sequence_id=self.EOS,
                        end_of_sequence_id=self.EOS,
                        maximum_length=23,  # max_twee_len + 3; tf.reduce_max(self.de_out_len) + 3
                        num_decoder_symbols=self.vocab_size)
                    self.decoder_train_inputs_embedded = tf.nn.embedding_lookup(
                        self.embed, self.decoder_train_input)
                    (self.decoder_outputs_train, self.decoder_state_train,
                     self.decoder_context_state_train) = (
                         seq2seq.dynamic_rnn_decoder(
                             cell=self.decoder_cell,
                             decoder_fn=decoder_fn_train,
                             inputs=self.decoder_train_inputs_embedded,
                             sequence_length=self.decoder_train_length,
                             time_major=True,
                             scope=scope))

                    self.decoder_logits_train = output_fn(
                        self.decoder_outputs_train)
                    self.decoder_prediction_train = tf.argmax(
                        self.decoder_logits_train,
                        axis=-1,
                        name='decoder_prediction_train')

                    scope.reuse_variables()
                    (self.decoder_logits_inference,
                     self.decoder_state_inference,
                     self.decoder_context_state_inference) = (
                         seq2seq.dynamic_rnn_decoder(
                             cell=self.decoder_cell,
                             decoder_fn=decoder_fn_inference,
                             time_major=True,
                             scope=scope))
                    self.decoder_prediction_inference = tf.argmax(
                        self.decoder_logits_inference,
                        axis=-1,
                        name='decoder_prediction_inference')



        return self.de_out, self.de_out_len, self.title_out, self.first_out, self.decoder_logits_train, \
               self.decoder_prediction_train, self.loss_weights, self.decoder_train_targets, \
               self.decoder_train_title, self.decoder_train_first, self.decoder_prediction_inference
Example #26
    def _init_decoder(self):
        with tf.variable_scope("Decoder") as scope:

            def output_fn(outputs):
                # Project the decoder outputs to vocabulary logits.
                return tf.contrib.layers.linear(
                    outputs, self.vocab_size, scope=scope)

            if not self.attention:
                # Training decoder function used by dynamic_rnn_decoder.
                decoder_fn_train = seq2seq.simple_decoder_fn_train(
                    encoder_state=self.encoder_state)

                # Refer to https://github.com/tensorflow/tensorflow/blob/r1.0/tensorflow/contrib/seq2seq/python/ops/decoder_fn.py#L182

                # Inference decoder function for a sequence-to-sequence model;
                # it should be used when dynamic_rnn_decoder is in inference mode.
                decoder_fn_inference = seq2seq.simple_decoder_fn_inference(
                    # output_fn projects the cell outputs to vocabulary logits;
                    # it is used inside dynamic_rnn_decoder
                    output_fn=output_fn,
                    encoder_state=self.encoder_state,
                    embeddings=self.embedding_matrix,
                    start_of_sequence_id=self.EOS,
                    end_of_sequence_id=self.EOS,
                    maximum_length=tf.reduce_max(self.encoder_inputs_length) + 3,
                    num_decoder_symbols=self.vocab_size,
                )
            else:

                # attention_states: shape [batch_size, max_time, num_units];
                # take the (time-major) encoder hidden states as attention states.
                attention_states = tf.transpose(self.encoder_outputs, [1, 0, 2])

                # prepare_attention builds everything needed to compute attention
                # vectors:
                #   attention_keys: the encoder states projected for scoring
                #   attention_values: the encoder states that are attended over
                #   attention_score_fn: given the decoder state and the encoder
                #       states, outputs the context vector
                #   attention_construct_fn: concatenates the context vector with
                #       the query and projects it to form the attention input
                (attention_keys, attention_values, attention_score_fn,
                 attention_construct_fn) = seq2seq.prepare_attention(
                     attention_states=attention_states,
                     attention_option="bahdanau",
                     num_units=self.decoder_hidden_units,
                 )
                print("Number of decoder hidden units:")
                print(self.decoder_hidden_units)
                print("Attention keys:")
                print(attention_keys)
                print("Attention score function:")
                print(attention_score_fn)

                # attention_decoder_fn_train initializes the decoder input state, the attention
                # machinery and related bookkeeping; the result is passed to dynamic_rnn_decoder.
                # The returned decoder_fn takes (time, cell_state, cell_input, cell_output, context_state).
                decoder_fn_train = seq2seq.attention_decoder_fn_train(  # used to train the dynamic decoder
                    encoder_state=self.encoder_state,  # final encoder state; the bidirectional states are concatenated (c or h)
                    attention_keys=attention_keys,  # the transformation of each encoder output
                    attention_values=attention_values,  # the encoder states attended over
                    attention_score_fn=attention_score_fn,  # produces the context vector
                    attention_construct_fn=attention_construct_fn,  # combines the pieces above and also returns the hidden state
                    name='attention_decoder')
                # Calling decoder_fn returns (done, next_state, next_input, emit_output, next_context_state).
                # emit_output (the cell_output) is the cell's output after attention and the non-linearity:
                # the RNN output concatenated with the attention vector and passed through a linear layer.
                # next_input = array_ops.concat([cell_input, attention], 1)  # next cell input
                # context_state is only modified when using beam search.
                # The inference function below is the same idea, but it decodes greedily:
                # in the inference model cell_output = output_fn(cell_output), i.e. we get logits, and
                # next_input = array_ops.concat([cell_input, attention], 1)

                decoder_fn_inference = seq2seq.attention_decoder_fn_inference(  # used in the inference model
                    output_fn=output_fn,  # projects the cell output to logits; the argmax token is embedded and concatenated with the attention vector
                    encoder_state=self.encoder_state,
                    attention_keys=attention_keys,
                    attention_values=attention_values,
                    attention_score_fn=attention_score_fn,  # same role as above
                    attention_construct_fn=attention_construct_fn,
                    embeddings=self.embedding_matrix,
                    start_of_sequence_id=self.EOS,
                    end_of_sequence_id=self.EOS,
                    maximum_length=tf.reduce_max(self.encoder_inputs_length) + 3,
                    num_decoder_symbols=self.vocab_size,
                )


            # dynamic_rnn_decoder below does the actual decoding with the help of the functions above.
            # It can be used for training or inference, but needs a separate decoder_fn for each.
            # About decoder_context_state_train: one way to diversify the inference output is to use a
            # stochastic decoder_fn, in which case one would want to store the decoded outputs, not just
            # the RNN outputs. This can be done by maintaining a TensorArray in context_state and storing
            # the decoded output of each iteration there.

            (
                self.decoder_outputs_train,  # outputs from each cell; [max_time, batch_size, cell.output_size] here since time_major=True
                self.decoder_state_train,  # the final state, shaped [batch_size, cell.state_size]
                self.decoder_context_state_train  # described above
            ) = (
                seq2seq.dynamic_rnn_decoder(
                    cell=self.decoder_cell,
                    decoder_fn=decoder_fn_train,  # decoder_fn allows modeling of early stopping, output, state, and next input and context
                    inputs=self.decoder_train_inputs_embedded,  # decoder inputs, used at training time only
                    sequence_length=self.decoder_train_length,  # needed at training time (when inputs is not None) for dynamic unrolling; not needed at test time when inputs is None
                    time_major=True,  # inputs and outputs are shaped [max_time, batch_size, ...]
                    scope=scope,
                ))

            self.decoder_logits_train = output_fn(
                self.decoder_outputs_train
            )  # run the decoder hidden states through the linear output layer; the argmax is taken below
            self.decoder_prediction_train = tf.argmax(
                self.decoder_logits_train,
                axis=-1,
                name='decoder_prediction_train')

            scope.reuse_variables()

            (
                self.decoder_logits_inference,  # same as above, but no inputs are provided; the decoder feeds its own predictions back in
                self.decoder_state_inference,
                self.decoder_context_state_inference) = (
                    seq2seq.dynamic_rnn_decoder(
                        cell=self.decoder_cell,
                        decoder_fn=decoder_fn_inference,  # the inference decoder function
                        time_major=True,
                        scope=scope,
                    ))
            self.decoder_prediction_inference = tf.argmax(
                self.decoder_logits_inference,
                axis=-1,
                name='decoder_prediction_inference'
            )  # predicted output at each time step
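
# Minimal NumPy sketch (not from the model above) of what prepare_attention computes
# per decoder step with the "bahdanau" option: the keys are a learned projection of
# the encoder outputs, the score is v^T tanh(keys + W_q q), and the context vector is
# the softmax-weighted sum of the attention values. W_k, W_q and v stand in for the
# variables prepare_attention would create internally; the sizes are placeholders.
import numpy as np

max_time, num_units = 5, 4
encoder_outputs = np.random.randn(max_time, num_units)  # one batch element
query = np.random.randn(num_units)                      # current decoder state
W_k = np.random.randn(num_units, num_units)
W_q = np.random.randn(num_units, num_units)
v = np.random.randn(num_units)

keys = encoder_outputs @ W_k                             # attention_keys
scores = np.tanh(keys + query @ W_q) @ v                 # one score per encoder step
weights = np.exp(scores) / np.exp(scores).sum()          # softmax over time
context = weights @ encoder_outputs                      # weighted sum of attention_values
print(context.shape)                                     # (num_units,)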
Beispiel #27
0
    def _init_decoder(self):
        with tf.variable_scope("Decoder") as scope:

            def output_fn(outputs):
                return tc.layers.fully_connected(outputs,
                                                 self.output_symbol_size,
                                                 activation_fn=None,
                                                 scope=scope)

            if not self.attention:
                decoder_fn_train = seq2seq.simple_decoder_fn_train(
                    encoder_state=self.encoder_state)
                decoder_fn_inference = seq2seq.simple_decoder_fn_inference(
                    output_fn=output_fn,
                    encoder_state=self.encoder_state,
                    embeddings=self.embedding_matrix,
                    start_of_sequence_id=self.EOS,
                    end_of_sequence_id=self.EOS,
                    maximum_length=tf.reduce_max(self.encoder_inputs_length),
                    num_decoder_symbols=self.output_symbol_size)
            else:
                (attention_keys, attention_values, attention_score_fn,
                 attention_construct_fn) = seq2seq.prepare_attention(
                     attention_states=self.encoder_outputs,
                     attention_option="bahdanau",
                     num_units=self.decoder_hidden_units,
                 )

                decoder_fn_train = seq2seq.attention_decoder_fn_train(
                    encoder_state=self.encoder_state,
                    attention_keys=attention_keys,
                    attention_values=attention_values,
                    attention_score_fn=attention_score_fn,
                    attention_construct_fn=attention_construct_fn,
                    name='attention_decoder')

                decoder_fn_inference = seq2seq.attention_decoder_fn_inference(
                    output_fn=output_fn,
                    encoder_state=self.encoder_state,
                    attention_keys=attention_keys,
                    attention_values=attention_values,
                    attention_score_fn=attention_score_fn,
                    attention_construct_fn=attention_construct_fn,
                    embeddings=self.embedding_matrix,
                    start_of_sequence_id=self.EOS,
                    end_of_sequence_id=self.EOS,
                    maximum_length=tf.reduce_max(self.encoder_inputs_length),
                    num_decoder_symbols=self.output_symbol_size,
                )
            if self.is_training:
                (self.decoder_outputs_train, self.decoder_state_train,
                 self.decoder_context_state_train) = (
                     seq2seq.dynamic_rnn_decoder(
                         cell=self.decoder_cell,
                         decoder_fn=decoder_fn_train,
                         inputs=self.decoder_train_inputs_embedded,
                         sequence_length=self.decoder_train_length,
                         time_major=False,
                         scope=scope,
                     ))

                self.decoder_logits_train = output_fn(
                    self.decoder_outputs_train)
                self.decoder_prediction_train = tf.argmax(
                    self.decoder_logits_train,
                    axis=-1,
                    name='decoder_prediction_train')

                scope.reuse_variables()

            (self.decoder_logits_inference, self.decoder_state_inference,
             self.decoder_context_state_inference) = (
                 seq2seq.dynamic_rnn_decoder(
                     cell=self.decoder_cell,
                     decoder_fn=decoder_fn_inference,
                     time_major=False,
                     scope=scope,
                 ))
            self.decoder_prediction_inference = tf.argmax(
                self.decoder_logits_inference,
                axis=-1,
                name='decoder_prediction_inference')
    def __init__(self,
                 vocab_size,
                 embed_size,
                 num_layers,
                 hidden_size,
                 eos,
                 max_len,
                 initial_embed=None):
        super(Seq2SeqAttn, self).__init__()
        self.vocab_size = vocab_size
        self.embed_size = embed_size
        self.num_layers = num_layers
        self.hidden_size = hidden_size

        # TO DO
        self.EOS = eos
        self.PAD = 0

        # placeholders
        self.encoder_inputs = tf.placeholder(shape=(None, None),
                                             dtype=tf.int32,
                                             name='encoder_inputs')
        self.encoder_inputs_length = tf.placeholder(
            shape=(None, ), dtype=tf.int32, name='encoder_inputs_length')
        self.decoder_targets = tf.placeholder(shape=(None, None),
                                              dtype=tf.int32,
                                              name='decoder_targets')
        self.rewards = tf.placeholder(shape=(None, ),
                                      dtype=tf.float32,
                                      name='rewards')
        # self.decoder_targets_length = tf.placeholder(shape=(None,), dtype=tf.int32, name='decoder_targets_length')

        # LSTM cell
        cells = []
        for _ in range(num_layers):
            cells.append(tf.contrib.rnn.LSTMCell(hidden_size))

        self.encoder_cell = MultiRNNCell(cells)
        self.decoder_cell = MultiRNNCell(cells)

        # decoder train feeds
        with tf.name_scope('decoder_feeds'):
            sequence_size, batch_size = tf.unstack(
                tf.shape(self.decoder_targets))
            EOS_SLICE = tf.ones([1, batch_size], dtype=tf.int32) * self.EOS
            PAD_SLICE = tf.ones([1, batch_size], dtype=tf.int32) * self.PAD
            self.decoder_train_inputs = tf.concat(
                [EOS_SLICE, self.decoder_targets], axis=0)

            # self.decoder_train_length = self.decoder_targets_length + 1 # one for EOS
            self.decoder_train_length = tf.cast(
                tf.ones(shape=(batch_size, )) * (max_len + 1), tf.int32)

            decoder_train_targets = tf.concat(
                [self.decoder_targets, PAD_SLICE],
                axis=0)  # (seq_len + 1) * batch_size
            # decoder_train_targets_seq_len, _ = tf.unstack(tf.shape(decoder_train_targets))
            # decoder_train_targets_eos_mask = tf.one_hot(
            # self.decoder_targets_length,
            # decoder_train_targets_seq_len,
            # on_value=self.EOS,
            # off_value=self.PAD,
            # dtype=tf.int32
            # ) # batch_size * (seq_len + 1)
            # decoder_train_targets_eos_mask = tf.transpose(decoder_train_targets_eos_mask, [1, 0])
            # decoder_train_targets = tf.add(decoder_train_targets, decoder_train_targets_eos_mask)
            self.decoder_train_targets = decoder_train_targets  # (seq_len + 1) * batch_size with EOS at the end of each sentence
            # loss_weights
            loss_weights = tf.cast(
                tf.cast(self.decoder_train_targets, tf.bool), tf.float32)
            self.loss_weights = tf.transpose(loss_weights, perm=[1, 0])
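            # e.g. with PAD = 0, a time-major target column [5, 6, 1, 0] gives weights
            # [1., 1., 1., 0.] for that batch element after the transpose, so PAD positions
            # contribute nothing to the sequence loss.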
            # self.loss_weights = tf.ones([batch_size, tf.reduce_max(self.decoder_train_length)], dtype=tf.float32, name="loss_weights")

        # embedding layer
        with tf.variable_scope('embedding'):
            # if initial_embed:
            self.embedding = tf.Variable(initial_embed,
                                         name='matrix',
                                         dtype=tf.float32)
            # else:
            # self.embedding = tf.Variable(tf.random_normal([vocab_size, embed_size], - 0.5 / embed_size, 0.5 / embed_size), name='matrix', dtype=tf.float32)
            self.encoder_inputs_embedded = tf.nn.embedding_lookup(
                self.embedding, self.encoder_inputs)
            self.decoder_train_inputs_embedded = tf.nn.embedding_lookup(
                self.embedding, self.decoder_train_inputs)

        # encoder
        with tf.variable_scope('encoder'):
            self.encoder_outputs, self.encoder_state = tf.nn.dynamic_rnn(
                cell=self.encoder_cell,
                inputs=self.encoder_inputs_embedded,
                sequence_length=self.encoder_inputs_length,
                time_major=True,
                dtype=tf.float32)

        # decoder
        with tf.variable_scope('decoder') as scope:

            def output_fn(outputs):
                return tf.contrib.layers.linear(outputs,
                                                self.vocab_size,
                                                scope=scope)

            attention_states = tf.transpose(
                self.encoder_outputs,
                [1, 0, 2])  # batch_size * seq_len * hidden_size

            (attention_keys, attention_values, attention_score_fn,
             attention_construct_fn) = seq2seq.prepare_attention(
                 attention_states=attention_states,
                 attention_option="bahdanau",
                 num_units=self.hidden_size,
             )

            decoder_fn_train = seq2seq.attention_decoder_fn_train(
                encoder_state=self.encoder_state,
                attention_keys=attention_keys,
                attention_values=attention_values,
                attention_score_fn=attention_score_fn,
                attention_construct_fn=attention_construct_fn,
                name='attention_decoder')

            decoder_fn_inference = seq2seq.attention_decoder_fn_inference(
                output_fn=output_fn,
                encoder_state=self.encoder_state,
                attention_keys=attention_keys,
                attention_values=attention_values,
                attention_score_fn=attention_score_fn,
                attention_construct_fn=attention_construct_fn,
                embeddings=self.embedding,
                start_of_sequence_id=self.EOS,
                end_of_sequence_id=self.EOS,
                maximum_length=35,
                num_decoder_symbols=self.vocab_size,
            )

            self.decoder_outputs_train, self.decoder_state_train, self.decoder_context_state_train = seq2seq.dynamic_rnn_decoder(
                cell=self.decoder_cell,
                decoder_fn=decoder_fn_train,
                inputs=self.decoder_train_inputs_embedded,
                sequence_length=self.decoder_train_length,
                time_major=True,
                scope=scope,
            )

            self.decoder_logits_train = output_fn(self.decoder_outputs_train)
            self.decoder_prediction_train = tf.argmax(
                self.decoder_logits_train,
                axis=-1,
                name='decoder_prediction_train')

            scope.reuse_variables()

            self.decoder_logits_inference, self.decoder_state_inference, self.decoder_context_state_inference = (
                seq2seq.dynamic_rnn_decoder(
                    cell=self.decoder_cell,
                    decoder_fn=decoder_fn_inference,
                    time_major=True,
                    scope=scope,
                ))

            self.decoder_prediction_inference = tf.argmax(
                self.decoder_logits_inference,
                axis=-1,
                name='decoder_prediction_inference')

        # optimizer
        with tf.name_scope('optimizer'):
            self.global_step = tf.Variable(0,
                                           name="global_step",
                                           trainable=False)
            # self.policy_step = tf.Variable(0, name='policy_step', trainable=False)
            logits = tf.transpose(
                self.decoder_logits_train,
                [1, 0, 2])  # batch_size * sequence_length * vocab_size
            targets = tf.transpose(self.decoder_train_targets, [1, 0])

            logits_inference = tf.transpose(self.decoder_logits_inference,
                                            [1, 0, 2])
            output_prob = tf.reduce_max(tf.nn.softmax(logits_inference),
                                        axis=2)  # batch_size * seq_len
            seq_log_prob = tf.reduce_sum(tf.log(output_prob), axis=1)
            self.policy_loss = -tf.reduce_sum(self.rewards * seq_log_prob)
            self.policy_op = tf.train.AdamOptimizer().minimize(
                self.policy_loss)

            self.loss = seq2seq.sequence_loss(logits=logits,
                                              targets=targets,
                                              weights=self.loss_weights)
            self.train_op = tf.train.AdamOptimizer().minimize(
                self.loss, global_step=self.global_step)
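
# Tiny NumPy illustration (assumed toy values, not part of the class above) of the
# REINFORCE-style policy loss used in the optimizer block: for each sequence, sum the
# log-probability of the most likely token at every step, weight it by that sequence's
# reward, and negate the total.
import numpy as np

logits = np.random.randn(2, 3, 5)                      # batch_size=2, seq_len=3, vocab=5
rewards = np.array([1.0, -0.5])
probs = np.exp(logits) / np.exp(logits).sum(axis=2, keepdims=True)
seq_log_prob = np.log(probs.max(axis=2)).sum(axis=1)   # one value per batch element
policy_loss = -(rewards * seq_log_prob).sum()
print(policy_loss)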
Beispiel #29
0
if __name__ == "__main__":
    import numpy as np
    from tensorflow.contrib import seq2seq
    from tensorflow.contrib import rnn

    sequence_length = 25
    vocab_size = 100
    hidden_dims = 10
    batch_size = 32

    cell = rnn.BasicRNNCell(hidden_dims)
    encoder_state = tf.constant(np.random.randn(batch_size, hidden_dims),
                                dtype=np.float32)
    embeddings = tf.constant(np.random.randn(vocab_size, hidden_dims),
                             dtype=np.float32)
    output_W = tf.transpose(embeddings)  # -- tied embeddings
    output_b = tf.constant(np.random.randn(vocab_size), dtype=np.float32)
    output_projections = (output_W, output_b)
    maximum_length = tf.reduce_max(sequence_length) + 3

    decoder_fn = gumbel_decoder_fn(encoder_state, embeddings,
                                   output_projections, maximum_length)

    outputs, final_state, final_context_state = seq2seq.dynamic_rnn_decoder(
        cell, decoder_fn, inputs=None, sequence_length=sequence_length)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        outputs_ = sess.run(outputs)
        print("outputs shape", outputs_.shape)
Beispiel #30
0
    def _build_graph(self):
        # required only for training
        self.targets = tf.placeholder(shape=(None, None),
                                      dtype=tf.int32,
                                      name="decoder_inputs")
        self.targets_length = tf.placeholder(shape=(None, ),
                                             dtype=tf.int32,
                                             name="decoder_inputs_length")
        self.global_step = tf.Variable(0, name="global_step", trainable=False)

        with tf.name_scope("DecoderTrainFeed"):
            sequence_size, batch_size = tf.unstack(tf.shape(self.targets))

            EOS_SLICE = tf.ones([1, batch_size], dtype=tf.int32) * self.EOS
            PAD_SLICE = tf.ones([1, batch_size], dtype=tf.int32) * self.PAD

            self.train_inputs = tf.concat([EOS_SLICE, self.targets], axis=0)
            self.train_length = self.targets_length + 1

            train_targets = tf.concat([self.targets, PAD_SLICE], axis=0)
            train_targets_seq_len, _ = tf.unstack(tf.shape(train_targets))
            train_targets_eos_mask = tf.one_hot(self.train_length - 1,
                                                train_targets_seq_len,
                                                on_value=self.EOS,
                                                off_value=self.PAD,
                                                dtype=tf.int32)
            train_targets_eos_mask = tf.transpose(train_targets_eos_mask,
                                                  [1, 0])

            # hacky way using one_hot to put EOS symbol at the end of target sequence
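            # e.g. assuming EOS = 1, PAD = 0 and one target column [5, 6, 7] (targets_length = 3):
            #   train_targets after the PAD concat -> [5, 6, 7, 0]
            #   one_hot(train_length - 1, 4) mask  -> [0, 0, 0, 1]
            #   tf.add of the two                  -> [5, 6, 7, 1]   (EOS now ends the sequence)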
            train_targets = tf.add(train_targets, train_targets_eos_mask)

            self.train_targets = train_targets

            self.loss_weights = tf.ones(
                [batch_size, tf.reduce_max(self.train_length)],
                dtype=tf.float32,
                name="loss_weights")

        with tf.variable_scope("embedding") as scope:
            self.inputs_embedded = tf.nn.embedding_lookup(
                self.embedding_matrix, self.train_inputs)

        with tf.variable_scope("Decoder") as scope:

            def logits_fn(outputs):
                return layers.linear(outputs, self.vocab_size, scope=scope)

            if not self.attention:
                train_fn = seq2seq.simple_decoder_fn_train(
                    encoder_state=self.encoder_state)
                inference_fn = seq2seq.simple_decoder_fn_inference(
                    output_fn=logits_fn,
                    encoder_state=self.encoder_state,
                    embeddings=self.embedding_matrix,
                    start_of_sequence_id=self.EOS,
                    end_of_sequence_id=self.EOS,
                    maximum_length=tf.reduce_max(self.encoder_inputs_length) +
                    3,
                    num_decoder_symbols=self.vocab_size)
            else:

                # attention_states: size [batch_size, max_time, num_units]
                attention_states = tf.transpose(self.encoder_outputs,
                                                [1, 0, 2])

                (attention_keys, attention_values, attention_score_fn,
                 attention_construct_fn) = seq2seq.prepare_attention(
                     attention_states=attention_states,
                     attention_option="bahdanau",
                     num_units=self.decoder_hidden_units)

                train_fn = seq2seq.attention_decoder_fn_train(
                    encoder_state=self.encoder_state,
                    attention_keys=attention_keys,
                    attention_values=attention_values,
                    attention_score_fn=attention_score_fn,
                    attention_construct_fn=attention_construct_fn,
                    name="decoder_attention")

                inference_fn = seq2seq.attention_decoder_fn_inference(
                    output_fn=logits_fn,
                    encoder_state=self.encoder_state,
                    attention_keys=attention_keys,
                    attention_values=attention_values,
                    attention_score_fn=attention_score_fn,
                    attention_construct_fn=attention_construct_fn,
                    embeddings=self.embedding_matrix,
                    start_of_sequence_id=self.EOS,
                    end_of_sequence_id=self.EOS,
                    maximum_length=tf.reduce_max(self.encoder_inputs_length) +
                    3,
                    num_decoder_symbols=self.vocab_size)

            (self.train_outputs, self.train_state,
             self.train_context_state) = seq2seq.dynamic_rnn_decoder(
                 cell=self.cell,
                 decoder_fn=train_fn,
                 inputs=self.inputs_embedded,
                 sequence_length=self.train_length,
                 time_major=True,
                 scope=scope)

            self.train_logits = logits_fn(self.train_outputs)
            self.train_prediction = tf.argmax(self.train_logits,
                                              axis=-1,
                                              name="train_prediction")
            self.train_prediction_probabilities = tf.nn.softmax(
                self.train_logits,
                dim=-1,
                name="train_prediction_probabilities")

            scope.reuse_variables()

            (self.inference_logits, self.inference_state,
             self.inference_context_state) = seq2seq.dynamic_rnn_decoder(
                 cell=self.cell,
                 decoder_fn=inference_fn,
                 time_major=True,
                 scope=scope)

            self.inference_prediction = tf.argmax(self.inference_logits,
                                                  axis=-1,
                                                  name="inference_prediction")
            self.inference_prediction_probabilities = tf.nn.softmax(
                self.inference_logits,
                dim=-1,
                name="inference_prediction_probabilities")
Beispiel #31
0
    def _init_decoder(self):
        with tf.variable_scope("Decoder") as scope:
            def output_fn(outputs):
                self.test_outputs = outputs
                return tf.contrib.layers.linear(outputs, self.decoder_vocab_size, scope=scope)

            if not self.attention:
                decoder_fn_train = seq2seq.simple_decoder_fn_train(encoder_state=self.encoder_state)
                decoder_fn_inference = seq2seq.simple_decoder_fn_inference(
                    output_fn=output_fn,
                    encoder_state=self.encoder_state,
                    embeddings=self.decoder_embedding_matrix,
                    start_of_sequence_id=self.EOS,
                    end_of_sequence_id=self.EOS,
                    maximum_length=tf.reduce_max(self.encoder_inputs_length) + 100,
                    num_decoder_symbols=self.decoder_vocab_size,
                )
            else:

                # attention_states: size [batch_size, max_time, num_units]
                attention_states = tf.transpose(self.encoder_outputs, [1, 0, 2])

                (attention_keys,
                attention_values,
                attention_score_fn,
                attention_construct_fn) = seq2seq.prepare_attention(
                    attention_states=attention_states,
                    attention_option="bahdanau",
                    num_units=self.decoder_hidden_units,
                )

                decoder_fn_train = seq2seq.attention_decoder_fn_train(
                    encoder_state=self.encoder_state,
                    attention_keys=attention_keys,
                    attention_values=attention_values,
                    attention_score_fn=attention_score_fn,
                    attention_construct_fn=attention_construct_fn,
                    name='attention_decoder'
                )

                decoder_fn_inference = seq2seq.attention_decoder_fn_inference(
                    output_fn=output_fn,
                    encoder_state=self.encoder_state,
                    attention_keys=attention_keys,
                    attention_values=attention_values,
                    attention_score_fn=attention_score_fn,
                    attention_construct_fn=attention_construct_fn,
                    embeddings=self.decoder_embedding_matrix,
                    start_of_sequence_id=self.EOS,
                    end_of_sequence_id=self.EOS,
                    maximum_length=tf.reduce_max(self.encoder_inputs_length) + 100,
                    num_decoder_symbols=self.decoder_vocab_size,
                )

            (self.decoder_outputs_train,
             self.decoder_state_train,
             self.decoder_context_state_train) = (
                seq2seq.dynamic_rnn_decoder(
                    cell=self.decoder_cell,
                    decoder_fn=decoder_fn_train,
                    inputs=self.decoder_train_inputs_embedded,
                    sequence_length=self.decoder_train_length,
                    time_major=self.time_major,
                    scope=scope,
                )
            )

            self.decoder_logits_train = output_fn(self.decoder_outputs_train)
            self.decoder_prediction_train = tf.argmax(self.decoder_logits_train, axis=-1, name='decoder_prediction_train')

            scope.reuse_variables()

            (self.decoder_logits_inference,
             self.decoder_state_inference,
             self.decoder_context_state_inference) = (
                seq2seq.dynamic_rnn_decoder(
                    cell=self.decoder_cell,
                    decoder_fn=decoder_fn_inference,
                    time_major=self.time_major,
                    scope=scope,
                )
            )
            self.decoder_prediction_inference = tf.argmax(self.decoder_logits_inference, axis=-1, name='decoder_prediction_inference')
Beispiel #32
0
    def __init_decoder(self):
        '''Initializes the decoder part of the model.'''
        with tf.variable_scope('decoder') as scope:
            output_fn = lambda outs: layers.linear(
                outs, self.__get_vocab_size(), scope=scope)

            if self.cfg.get('use_attention'):
                attention_states = tf.transpose(self.encoder_outputs,
                                                [1, 0, 2])

                (attention_keys, attention_values, attention_score_fn,
                 attention_construct_fn) = seq2seq.prepare_attention(
                     attention_states=attention_states,
                     attention_option='bahdanau',
                     num_units=self.decoder_cell.output_size)

                decoder_fn_train = seq2seq.attention_decoder_fn_train(
                    encoder_state=self.encoder_state,
                    attention_keys=attention_keys,
                    attention_values=attention_values,
                    attention_score_fn=attention_score_fn,
                    attention_construct_fn=attention_construct_fn,
                    name='attention_decoder')

                decoder_fn_inference = seq2seq.attention_decoder_fn_inference(
                    output_fn=output_fn,
                    encoder_state=self.encoder_state,
                    attention_keys=attention_keys,
                    attention_values=attention_values,
                    attention_score_fn=attention_score_fn,
                    attention_construct_fn=attention_construct_fn,
                    embeddings=self.embeddings,
                    start_of_sequence_id=Config.EOS_WORD_IDX,
                    end_of_sequence_id=Config.EOS_WORD_IDX,
                    maximum_length=tf.reduce_max(self.encoder_inputs_length) +
                    3,
                    num_decoder_symbols=self.__get_vocab_size())
            else:
                decoder_fn_train = seq2seq.simple_decoder_fn_train(
                    encoder_state=self.encoder_state)
                decoder_fn_inference = seq2seq.simple_decoder_fn_inference(
                    output_fn=output_fn,
                    encoder_state=self.encoder_state,
                    embeddings=self.embeddings,
                    start_of_sequence_id=Config.EOS_WORD_IDX,
                    end_of_sequence_id=Config.EOS_WORD_IDX,
                    maximum_length=tf.reduce_max(self.encoder_inputs_length) +
                    3,
                    num_decoder_symbols=self.__get_vocab_size())

            (self.decoder_outputs_train, self.decoder_state_train,
             self.decoder_context_state_train) = seq2seq.dynamic_rnn_decoder(
                 cell=self.decoder_cell,
                 decoder_fn=decoder_fn_train,
                 inputs=self.decoder_train_inputs_embedded,
                 sequence_length=self.decoder_train_length,
                 time_major=True,
                 scope=scope)

            self.decoder_logits_train = output_fn(self.decoder_outputs_train)
            self.decoder_prediction_train = tf.argmax(
                self.decoder_logits_train,
                axis=-1,
                name='decoder_prediction_train')

            scope.reuse_variables()

            (self.decoder_logits_inference, decoder_state_inference,
             self.decoder_context_state_inference
             ) = seq2seq.dynamic_rnn_decoder(cell=self.decoder_cell,
                                             decoder_fn=decoder_fn_inference,
                                             time_major=True,
                                             scope=scope)

            self.decoder_prediction_inference = tf.argmax(
                self.decoder_logits_inference,
                axis=-1,
                name='decoder_prediction_inference')