# Example #1
    def init_decoder_variable(self):
        """Build the decoder sub-graph for training or inference.

        Creates the decoder cell, the decoder embedding matrix, and the
        input/output projection layers, then branches on ``self.mode``:

        * ``'train'``: teacher-forced decoding via ``TrainingHelper``;
          defines the masked ``sequence_loss`` in ``self.loss``, a loss
          summary, and the optimizer graph (``self.init_optimizer()``).
        * ``'decode'``: greedy or beam-search inference; predictions are
          stored in ``self.decoder_pred_decode``.

        NOTE(review): reads many attributes assumed to be initialized
        before this call (self.mode, self.dtype, self.hidden_units,
        self.num_decoder_symbols, self.embedding_size, self.batch_size,
        self.start_token, self.end_token, self.beam_width,
        self.max_decode_step, the decoder_inputs_*_train tensors, and
        self.use_beamsearch_decode) — confirm against the enclosing class.
        """
        # Building decoder_cell and decoder_initial_state
        self.decoder_cell, self.decoder_initial_state = self.build_decoder_cell(
        )

        # Initialize decoder embeddings to have variance=1.
        sqrt3 = math.sqrt(3)  # Uniform(-sqrt(3), sqrt(3)) has variance=1.
        initializer = tf.random_uniform_initializer(-sqrt3,
                                                    sqrt3,
                                                    dtype=self.dtype)

        self.decoder_embeddings = tf.get_variable(
            name='embedding',
            shape=[self.num_decoder_symbols, self.embedding_size],
            initializer=initializer,
            dtype=self.dtype)

        # Input projection layer to feed embedded inputs to the cell
        # ** Essential when use_residual=True to match input/output dims
        input_layer = Dense(self.hidden_units,
                            dtype=self.dtype,
                            name='input_projection')

        # Output projection layer to convert cell_outputs to logits
        output_layer = Dense(self.num_decoder_symbols,
                             name='output_projection')

        if self.mode == 'train':
            # decoder_inputs_embedded: [batch_size, max_time_step + 1, embedding_size]
            self.decoder_inputs_embedded = tf.nn.embedding_lookup(
                params=self.decoder_embeddings, ids=self.decoder_inputs_train)

            # Embedded inputs having gone through input projection layer
            self.decoder_inputs_embedded = input_layer(
                self.decoder_inputs_embedded)

            # Helper to feed inputs for training: read inputs from dense ground truth vectors
            training_helper = seq2seq.TrainingHelper(
                inputs=self.decoder_inputs_embedded,
                sequence_length=self.decoder_inputs_length_train,
                time_major=False,
                name='training_helper')

            training_decoder = seq2seq.BasicDecoder(
                cell=self.decoder_cell,
                helper=training_helper,
                initial_state=self.decoder_initial_state,
                output_layer=output_layer)
            # output_layer=None)

            # Maximum decoder time_steps in current batch
            max_decoder_length = tf.reduce_max(
                self.decoder_inputs_length_train)

            # decoder_outputs_train: BasicDecoderOutput
            #                        namedtuple(rnn_outputs, sample_id)
            # decoder_outputs_train.rnn_output: [batch_size, max_time_step + 1, num_decoder_symbols] if output_time_major=False
            #                                   [max_time_step + 1, batch_size, num_decoder_symbols] if output_time_major=True
            # decoder_outputs_train.sample_id: [batch_size], tf.int32
            (self.decoder_outputs_train, self.decoder_last_state_train,
             self.decoder_outputs_length_train) = (seq2seq.dynamic_decode(
                 decoder=training_decoder,
                 output_time_major=False,
                 impute_finished=True,
                 maximum_iterations=max_decoder_length))

            # More efficient to do the projection on the batch-time-concatenated tensor
            # logits_train: [batch_size, max_time_step + 1, num_decoder_symbols]
            # self.decoder_logits_train = output_layer(self.decoder_outputs_train.rnn_output)
            # output_layer was already applied inside BasicDecoder, so
            # rnn_output is already the logits; tf.identity only gives the
            # tensor a stable name in the graph.
            self.decoder_logits_train = tf.identity(
                self.decoder_outputs_train.rnn_output)
            # Use argmax to extract decoder symbols to emit
            self.decoder_pred_train = tf.argmax(self.decoder_logits_train,
                                                axis=-1,
                                                name='decoder_pred_train')

            # masks: masking for valid and padded time steps, [batch_size, max_time_step + 1]
            masks = tf.sequence_mask(lengths=self.decoder_inputs_length_train,
                                     maxlen=max_decoder_length,
                                     dtype=self.dtype,
                                     name='masks')

            # Computes per word average cross-entropy over a batch
            # Internally calls 'nn_ops.sparse_softmax_cross_entropy_with_logits' by default
            self.loss = seq2seq.sequence_loss(
                logits=self.decoder_logits_train,
                targets=self.decoder_targets_train,
                weights=masks,
                average_across_timesteps=True,
                average_across_batch=True,
            )
            # Training summary for the current batch_loss
            tf.summary.scalar('loss', self.loss)

            # Construct graphs for minimizing loss
            self.init_optimizer()

        elif self.mode == 'decode':

            # Start_tokens: [batch_size,] `int32` vector
            # NOTE(review): assumes self.start_token / self.end_token are
            # scalar int vocabulary ids — confirm.
            start_tokens = tf.ones([
                self.batch_size,
            ], tf.int32) * self.start_token
            end_token = self.end_token

            # Shared embedding + input projection used by both helpers so
            # inference inputs match the training-time input pipeline.
            def embed_and_input_proj(inputs):
                return input_layer(
                    tf.nn.embedding_lookup(self.decoder_embeddings, inputs))

            if not self.use_beamsearch_decode:
                # Helper to feed inputs for greedy decoding: uses the argmax of the output
                decoding_helper = seq2seq.GreedyEmbeddingHelper(
                    start_tokens=start_tokens,
                    end_token=end_token,
                    embedding=embed_and_input_proj)
                # Basic decoder performs greedy decoding at each time step
                print("building greedy decoder..")
                inference_decoder = seq2seq.BasicDecoder(
                    cell=self.decoder_cell,
                    helper=decoding_helper,
                    initial_state=self.decoder_initial_state,
                    output_layer=output_layer)
            else:
                # Beamsearch is used to approximately find the most likely translation
                print("building beamsearch decoder..")
                inference_decoder = beam_search_decoder.BeamSearchDecoder(
                    cell=self.decoder_cell,
                    embedding=embed_and_input_proj,
                    start_tokens=start_tokens,
                    end_token=end_token,
                    initial_state=self.decoder_initial_state,
                    beam_width=self.beam_width,
                    output_layer=output_layer,
                )

            (self.decoder_outputs_decode, self.decoder_last_state_decode,
             self.decoder_outputs_length_decode) = (
                 seq2seq.dynamic_decode(
                     decoder=inference_decoder,
                     output_time_major=False,
                     # impute_finished=True,	# error occurs
                     maximum_iterations=self.max_decode_step))

            if not self.use_beamsearch_decode:
                # decoder_outputs_decode.sample_id: [batch_size, max_time_step]
                # Or use argmax to find decoder symbols to emit:
                # self.decoder_pred_decode = tf.argmax(self.decoder_outputs_decode.rnn_output,
                #                                      axis=-1, name='decoder_pred_decode')

                # Here, we use expand_dims to be compatible with the result of the beamsearch decoder
                # decoder_pred_decode: [batch_size, max_time_step, 1] (output_major=False)
                self.decoder_pred_decode = tf.expand_dims(
                    self.decoder_outputs_decode.sample_id, -1)

            else:
                # Use beam search to approximately find the most likely translation
                # decoder_pred_decode: [batch_size, max_time_step, beam_width] (output_major=False)
                self.decoder_pred_decode = self.decoder_outputs_decode.predicted_ids
# Example #2
def build_network(
    hparams,
    char2numY,
    inputs,
    dec_inputs,
    keep_prob_=0.5,
):
    """Build a seq2seq training + inference graph over a CNN front-end.

    Encodes the CNN features with a (bi)directional stacked LSTM, decodes
    with teacher forcing for training, and builds a separate greedy or
    beam-search decoder for inference.

    Args:
        hparams: hyper-parameter object; fields read here include
            akara2017, input_depth, max_time_step, embed_size, num_units,
            lstm_layers, bidirectional, use_attention, attention_size,
            batch_size, use_beamsearch_decode, beam_width and
            output_max_length.
        char2numY: dict mapping target symbols to int ids; must contain
            '<SOD>' and '<EOD>'.
        inputs: encoder input tensor, reshaped to
            [-1, hparams.input_depth, 1] for the CNN front-end.
        dec_inputs: int tensor of decoder input ids (teacher forcing).
        keep_prob_: dropout keep probability for the CNN front-end.

    Returns:
        Tuple (logits, pred_outputs, _final_state): teacher-forced
        decoder logits, inference predictions (greedy sample ids with
        shape [batch, time], or the top beam when beam search is on),
        and the final training-decoder state.

    Raises:
        ValueError: if hparams.akara2017 is not True, since no other
            encoder input path is defined.
    """
    if hparams.akara2017 is True:
        _inputs = tf.reshape(inputs, [-1, hparams.input_depth, 1])
        network = build_firstPart_model(_inputs, keep_prob_)

        shape = network.get_shape().as_list()
        data_input_embed = tf.reshape(network,
                                      (-1, hparams.max_time_step, shape[1]))
    else:
        # Bug fix: data_input_embed was only assigned on the akara2017
        # path, so any other configuration crashed later with a NameError.
        # Fail fast with an actionable message instead.
        raise ValueError(
            'build_network only supports hparams.akara2017 == True; '
            'no encoder input path is defined otherwise.')

    # Embedding layers
    # NOTE(review): scope name 'embeddin' looks like a typo for
    # 'embedding', but renaming it would invalidate existing checkpoints,
    # so it is kept as-is.
    with tf.variable_scope("embeddin") as embedding_scope:
        decoder_embedding = tf.Variable(
            tf.random_uniform((len(char2numY), hparams.embed_size), -1.0, 1.0),
            name='dec_embedding')  # +1 to consider <EOD>
        decoder_emb_inputs = tf.nn.embedding_lookup(decoder_embedding,
                                                    dec_inputs)

    with tf.variable_scope("encoding") as encoding_scope:
        if not hparams.bidirectional:

            # Regular approach with LSTM units
            # encoder_cell = tf.contrib.rnn.LSTMCell(hparams.num_units)
            # encoder_cell = tf.nn.rnn_cell.MultiRNNCell([encoder_cell] * hparams.lstm_layers)
            def lstm_cell():
                lstm = tf.contrib.rnn.LSTMCell(hparams.num_units)
                return lstm

            encoder_cell = tf.contrib.rnn.MultiRNNCell(
                [lstm_cell() for _ in range(hparams.lstm_layers)])
            encoder_outputs, encoder_state = tf.nn.dynamic_rnn(
                encoder_cell, inputs=data_input_embed, dtype=tf.float32)

        else:

            # Using a bidirectional LSTM architecture instead
            # enc_fw_cell = tf.contrib.rnn.LSTMCell(hparams.num_units)
            # enc_bw_cell = tf.contrib.rnn.LSTMCell(hparams.num_units)

            def lstm_cell():
                lstm = tf.contrib.rnn.LSTMCell(hparams.num_units)
                return lstm

            stacked_cell_fw = tf.contrib.rnn.MultiRNNCell(
                [lstm_cell() for _ in range(hparams.lstm_layers)],
                state_is_tuple=True)
            stacked_cell_bw = tf.contrib.rnn.MultiRNNCell(
                [lstm_cell() for _ in range(hparams.lstm_layers)],
                state_is_tuple=True)

            ((enc_fw_out, enc_bw_out),
             (enc_fw_final, enc_bw_final)) = tf.nn.bidirectional_dynamic_rnn(
                 cell_fw=stacked_cell_fw,
                 cell_bw=stacked_cell_bw,
                 inputs=data_input_embed,
                 dtype=tf.float32)
            # Concatenate fw/bw states per layer so the decoder (which uses
            # 2*num_units cells in this branch) can consume them directly.
            encoder_final_state = []
            for layer in range(hparams.lstm_layers):
                enc_fin_c = tf.concat(
                    (enc_fw_final[layer].c, enc_bw_final[layer].c), 1)
                enc_fin_h = tf.concat(
                    (enc_fw_final[layer].h, enc_bw_final[layer].h), 1)
                encoder_final_state.append(
                    tf.contrib.rnn.LSTMStateTuple(c=enc_fin_c, h=enc_fin_h))

            encoder_state = tuple(encoder_final_state)
            encoder_outputs = tf.concat((enc_fw_out, enc_bw_out), 2)

    with tf.variable_scope("decoding") as decoding_scope:

        output_layer = Dense(len(char2numY), use_bias=False)
        # Fixed decoder length for every batch element: max_time_step + 1
        # (the +1 covers the appended <EOD> position).
        decoder_lengths = np.ones(
            (hparams.batch_size), dtype=np.int32) * (hparams.max_time_step + 1)
        training_helper = tf.contrib.seq2seq.TrainingHelper(
            decoder_emb_inputs, decoder_lengths)

        if not hparams.bidirectional:
            # decoder_cell = tf.contrib.rnn.LSTMCell(hparams.num_units)
            def lstm_cell():
                lstm = tf.contrib.rnn.LSTMCell(hparams.num_units)
                return lstm

            decoder_cells = tf.contrib.rnn.MultiRNNCell(
                [lstm_cell() for _ in range(hparams.lstm_layers)])

        else:
            # decoder_cell = tf.contrib.rnn.LSTMCell(2 * hparams.num_units)
            def lstm_cell():
                lstm = tf.contrib.rnn.LSTMCell(2 * hparams.num_units)
                return lstm

            decoder_cells = tf.contrib.rnn.MultiRNNCell(
                [lstm_cell() for _ in range(hparams.lstm_layers)])

        if hparams.use_attention:
            # Create an attention mechanism
            attention_mechanism = tf.contrib.seq2seq.LuongAttention(
                hparams.num_units *
                2 if hparams.bidirectional else hparams.num_units,
                encoder_outputs,
                memory_sequence_length=None)

            decoder_cells = tf.contrib.seq2seq.AttentionWrapper(
                decoder_cells,
                attention_mechanism,
                attention_layer_size=hparams.attention_size,
                alignment_history=True)

            # Wrap the encoder state in the AttentionWrapper's state tuple.
            encoder_state = decoder_cells.zero_state(
                hparams.batch_size, tf.float32).clone(cell_state=encoder_state)

        # Basic Decoder and decode
        decoder = tf.contrib.seq2seq.BasicDecoder(decoder_cells,
                                                  training_helper,
                                                  encoder_state,
                                                  output_layer=output_layer)

        dec_outputs, _final_state, _final_sequence_lengths = tf.contrib.seq2seq.dynamic_decode(
            decoder, impute_finished=True)

        # dec_outputs, _ = tf.nn.dynamic_rnn(decoder_cell, inputs=decoder_emb_inputs, initial_state=encoder_state)

    logits = dec_outputs.rnn_output  # unnormalized pre-softmax scores

    # Inference
    start_tokens = tf.fill([hparams.batch_size], char2numY['<SOD>'])
    end_token = char2numY['<EOD>']
    # Beam search is only needed at inference time; training has the
    # ground truth available and does not search.
    if not hparams.use_beamsearch_decode:

        inference_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
            decoder_embedding, start_tokens, end_token)

        # Inference Decoder
        inference_decoder = tf.contrib.seq2seq.BasicDecoder(
            decoder_cells,
            inference_helper,
            encoder_state,
            output_layer=output_layer)
    else:

        # Replicate the encoder state beam_width times for beam search.
        encoder_state = tf.contrib.seq2seq.tile_batch(
            encoder_state, multiplier=hparams.beam_width)
        decoder_initial_state = decoder_cells.zero_state(
            hparams.batch_size * hparams.beam_width,
            tf.float32).clone(cell_state=encoder_state)

        inference_decoder = beam_search_decoder.BeamSearchDecoder(
            cell=decoder_cells,
            embedding=decoder_embedding,
            start_tokens=start_tokens,
            end_token=end_token,
            initial_state=decoder_initial_state,
            beam_width=hparams.beam_width,
            output_layer=output_layer)

    # Dynamic decoding
    outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(
        inference_decoder,
        impute_finished=False,
        maximum_iterations=hparams.output_max_length)
    pred_outputs = outputs.sample_id
    if hparams.use_beamsearch_decode:
        # [batch_size, max_time_step, beam_width]
        pred_outputs = pred_outputs[0]
    return logits, pred_outputs, _final_state
# Example #3
    def _testDynamicDecodeRNN(self,
                              time_major,
                              has_attention,
                              with_alignment_history=False):
        """Smoke-test BeamSearchDecoder end to end.

        Builds a BeamSearchDecoder over an LSTM cell (optionally wrapped
        with Bahdanau attention over random encoder outputs), runs
        ``decoder.dynamic_decode``, and checks both the static graph
        shapes and the shapes of the actually-decoded tensors.

        Args:
            time_major: if True, expected output shapes are transposed
                from [batch, time, ...] to [time, batch, ...].
            has_attention: wrap the cell in an AttentionWrapper (encoder
                memory and initial state tiled ``beam_width`` times) and
                enable a non-zero coverage penalty.
            with_alignment_history: forwarded to AttentionWrapper's
                ``alignment_history`` argument.
        """
        encoder_sequence_length = np.array([3, 2, 3, 1, 1])
        decoder_sequence_length = np.array([2, 0, 1, 2, 3])
        batch_size = 5
        decoder_max_time = 4
        input_depth = 7
        cell_depth = 9
        attention_depth = 6
        vocab_size = 20
        end_token = vocab_size - 1
        start_token = 0
        embedding_dim = 50
        max_out = max(decoder_sequence_length)
        output_layer = layers_core.Dense(vocab_size,
                                         use_bias=True,
                                         activation=None)
        beam_width = 3

        with self.cached_session() as sess:
            batch_size_tensor = constant_op.constant(batch_size)
            embedding = np.random.randn(vocab_size,
                                        embedding_dim).astype(np.float32)
            cell = rnn_cell.LSTMCell(cell_depth)
            initial_state = cell.zero_state(batch_size, dtypes.float32)
            coverage_penalty_weight = 0.0
            if has_attention:
                coverage_penalty_weight = 0.2
                inputs = array_ops.placeholder_with_default(
                    np.random.randn(batch_size, decoder_max_time,
                                    input_depth).astype(np.float32),
                    shape=(None, None, input_depth))
                # Beam search expects memory/state replicated per beam.
                tiled_inputs = beam_search_decoder.tile_batch(
                    inputs, multiplier=beam_width)
                tiled_sequence_length = beam_search_decoder.tile_batch(
                    encoder_sequence_length, multiplier=beam_width)
                attention_mechanism = attention_wrapper.BahdanauAttention(
                    num_units=attention_depth,
                    memory=tiled_inputs,
                    memory_sequence_length=tiled_sequence_length)
                initial_state = beam_search_decoder.tile_batch(
                    initial_state, multiplier=beam_width)
                cell = attention_wrapper.AttentionWrapper(
                    cell=cell,
                    attention_mechanism=attention_mechanism,
                    attention_layer_size=attention_depth,
                    alignment_history=with_alignment_history)
            # Decoder state is sized batch_size * beam_width.
            cell_state = cell.zero_state(dtype=dtypes.float32,
                                         batch_size=batch_size_tensor *
                                         beam_width)
            if has_attention:
                # Keep the AttentionWrapper state tuple but swap in the
                # tiled encoder state as its cell_state.
                cell_state = cell_state.clone(cell_state=initial_state)
            bsd = beam_search_decoder.BeamSearchDecoder(
                cell=cell,
                embedding=embedding,
                start_tokens=array_ops.fill([batch_size_tensor], start_token),
                end_token=end_token,
                initial_state=cell_state,
                beam_width=beam_width,
                output_layer=output_layer,
                length_penalty_weight=0.0,
                coverage_penalty_weight=coverage_penalty_weight)

            final_outputs, final_state, final_sequence_lengths = (
                decoder.dynamic_decode(bsd,
                                       output_time_major=time_major,
                                       maximum_iterations=max_out))

            # Transpose an expected [batch, time, ...] shape when
            # time_major output was requested.
            def _t(shape):
                if time_major:
                    return (shape[1], shape[0]) + shape[2:]
                return shape

            self.assertIsInstance(
                final_outputs,
                beam_search_decoder.FinalBeamSearchDecoderOutput)
            self.assertIsInstance(final_state,
                                  beam_search_decoder.BeamSearchDecoderState)

            # Static shapes: time dimension is unknown (None) in the graph.
            beam_search_decoder_output = final_outputs.beam_search_decoder_output
            self.assertEqual(
                _t((batch_size, None, beam_width)),
                tuple(beam_search_decoder_output.scores.get_shape().as_list()))
            self.assertEqual(
                _t((batch_size, None, beam_width)),
                tuple(final_outputs.predicted_ids.get_shape().as_list()))

            sess.run(variables.global_variables_initializer())
            sess_results = sess.run({
                'final_outputs':
                final_outputs,
                'final_state':
                final_state,
                'final_sequence_lengths':
                final_sequence_lengths
            })

            max_sequence_length = np.max(
                sess_results['final_sequence_lengths'])

            # A smoke test
            self.assertEqual(
                _t((batch_size, max_sequence_length, beam_width)),
                sess_results['final_outputs'].beam_search_decoder_output.
                scores.shape)
            self.assertEqual(
                _t((batch_size, max_sequence_length, beam_width)),
                sess_results['final_outputs'].beam_search_decoder_output.
                predicted_ids.shape)
# Example #4
    def build_decoder(self):
        """Build the decoder (with attention) for training or inference.

        Under the 'decoder' variable scope, creates the decoder cell,
        projection layers and embeddings, then branches on ``self.mode``:

        * ``'train'``: teacher-forced decoding; defines ``self.loss``
          (sum over timesteps of batch-averaged sequence loss), a
          teacher-forcing accuracy metric, summaries, and the optimizer.
        * ``'decode'``: greedy or beam-search inference; predictions end
          up in ``self.decoder_pred_decode``.

        NOTE(review): relies on attributes set elsewhere on the class
        (mode, dtype, hidden_units, num_decoder_symbols, embedding_size,
        batch_size, dest_start_token_index, dest_eos_token_index,
        beam_width, max_decode_step, max_decoder_length, the
        decoder_*_train tensors, use_beamsearch_decode) — confirm.
        """
        print("building decoder and attention..")
        with tf.variable_scope('decoder'):
            # Building decoder_cell and decoder_initial_state
            self.decoder_cell, self.decoder_initial_state = self.build_decoder_cell(
            )

            input_layer = Dense(self.hidden_units,
                                dtype=self.dtype,
                                name='input_projection')
            # Output projection layer to convert cell_outputs to logits
            output_layer = Dense(self.num_decoder_symbols,
                                 name='output_projection')

            if self.mode == 'train':
                # Uniform(-sqrt(3), sqrt(3)) gives the embeddings variance 1.
                initializer = tf.random_uniform_initializer(-math.sqrt(3),
                                                            math.sqrt(3),
                                                            dtype=self.dtype)
                self.decoder_embeddings = tf.get_variable(
                    name='embedding',
                    shape=[self.num_decoder_symbols, self.embedding_size],
                    initializer=initializer,
                    dtype=self.dtype)
                self.decoder_encoded = tf.nn.embedding_lookup(
                    params=self.decoder_embeddings,
                    ids=self.decoder_inputs_train)
                # Project embedded inputs to the cell's input size.
                self.decoder_inputs_encoded = input_layer(self.decoder_encoded)
                print(" Decoder input encoded is ",
                      self.decoder_inputs_encoded.shape)

                # Helper to feed inputs for training: read inputs from dense ground truth vectors
                training_helper = seq2seq.TrainingHelper(
                    inputs=self.decoder_inputs_encoded,
                    sequence_length=self.decoder_inputs_length_train,
                    time_major=False,
                    name='training_helper')

                training_decoder = seq2seq.BasicDecoder(
                    cell=self.decoder_cell,
                    helper=training_helper,
                    initial_state=self.decoder_initial_state,
                    output_layer=output_layer)

                (self.decoder_outputs_train, self.decoder_last_state_train,
                 self.decoder_outputs_length_train) = (seq2seq.dynamic_decode(
                     decoder=training_decoder,
                     output_time_major=False,
                     impute_finished=True,
                     swap_memory=True,
                     maximum_iterations=self.max_decoder_length))

                # More efficient to do the projection on the batch-time-concatenated tensor
                # logits_train: [batch_size, max_time_step + 1, num_decoder_symbols]
                # self.decoder_logits_train = output_layer(self.decoder_outputs_train.rnn_output)
                # output_layer was already applied inside BasicDecoder, so
                # rnn_output is already the logits; tf.identity only names it.
                self.decoder_logits_train = tf.identity(
                    self.decoder_outputs_train.rnn_output)
                # Use argmax to extract decoder symbols to emit
                self.decoder_pred_train = tf.argmax(self.decoder_logits_train,
                                                    axis=-1,
                                                    name='decoder_pred_train')

                # masks: masking for valid and padded time steps, [batch_size, max_time_step + 1]
                # NOTE(review): uses decoder_inputs_length_masks here while
                # the helper above uses decoder_inputs_length_train —
                # confirm this distinction is intentional.
                masks = tf.sequence_mask(
                    lengths=self.decoder_inputs_length_masks,
                    maxlen=self.max_decoder_length,
                    dtype=self.dtype,
                    name='masks')

                print("logits train shape is ",
                      self.decoder_logits_train.shape)
                print("decoder_targets_train train shape is ",
                      self.decoder_targets_train.shape)

                # Sum over timesteps of the batch-averaged per-step loss.
                self.loss = tf.reduce_sum(
                    seq2seq.sequence_loss(
                        logits=self.decoder_logits_train,
                        targets=self.decoder_targets_train,
                        weights=masks,
                        average_across_timesteps=False,
                        average_across_batch=True,
                    ))

                # Compute predictions
                self.accuracy, self.accuracy_op = tf.metrics.accuracy(
                    labels=self.decoder_targets_train,
                    predictions=self.decoder_pred_train,
                    name="accuracy")

                # Training summary for the current batch_loss
                tf.summary.scalar('loss', self.loss)
                tf.summary.scalar('teacher_forcing_accuracy', self.accuracy)

                # Construct graphs for minimizing loss
                self.init_optimizer()

            elif self.mode == 'decode':
                # No initializer here: the embedding is expected to be
                # restored from a checkpoint at inference time.
                self.decoder_embeddings = tf.get_variable(
                    name='embedding',
                    shape=[self.num_decoder_symbols, self.embedding_size],
                    dtype=self.dtype)

                # Start_tokens: [batch_size,] `int32` vector
                start_tokens = tf.ones([
                    self.batch_size,
                ], tf.int32) * self.dest_start_token_index
                end_token = self.dest_eos_token_index

                # Embedding lookup + input projection, mirroring training.
                def embed_and_input_proj(inputs):
                    encoded_input = tf.nn.embedding_lookup(
                        self.decoder_embeddings, inputs)
                    return input_layer(encoded_input)

                if not self.use_beamsearch_decode:
                    # Helper to feed inputs for greedy decoding: uses the argmax of the output
                    decoding_helper = seq2seq.GreedyEmbeddingHelper(
                        start_tokens=start_tokens,
                        end_token=end_token,
                        embedding=embed_and_input_proj)
                    # Basic decoder performs greedy decoding at each time step
                    print("building greedy decoder..")
                    inference_decoder = seq2seq.BasicDecoder(
                        cell=self.decoder_cell,
                        helper=decoding_helper,
                        initial_state=self.decoder_initial_state,
                        output_layer=output_layer)
                else:
                    # Beamsearch is used to approximately find the most likely translation
                    print("building beamsearch decoder..")
                    inference_decoder = beam_search_decoder.BeamSearchDecoder(
                        cell=self.decoder_cell,
                        embedding=embed_and_input_proj,
                        start_tokens=start_tokens,
                        end_token=end_token,
                        initial_state=self.decoder_initial_state,
                        beam_width=self.beam_width,
                        output_layer=output_layer,
                    )

                (self.decoder_outputs_decode, self.decoder_last_state_decode,
                 self.decoder_outputs_length_decode) = (seq2seq.dynamic_decode(
                     decoder=inference_decoder,
                     output_time_major=False,
                     swap_memory=True,
                     maximum_iterations=self.max_decode_step))

                if not self.use_beamsearch_decode:
                    # Here, we use expand_dims to be compatible with the result of the beamsearch decoder
                    # decoder_pred_decode: [batch_size, max_time_step, 1] (output_major=False)
                    self.decoder_pred_decode = tf.expand_dims(
                        self.decoder_outputs_decode.sample_id, -1)

                else:
                    # Use beam search to approximately find the most likely translation
                    # decoder_pred_decode: [batch_size, max_time_step, beam_width] (output_major=False)
                    self.decoder_pred_decode = self.decoder_outputs_decode.predicted_ids
# Example #5
    def build_decoder(self, keep_prob):
        """Build the decoder graph for training or greedy/beam inference.

        Under the 'decoder' variable scope: builds the decoder cell and
        initial state, the decoder embedding, and an output projection,
        then branches on ``self.mode``:

        * ``'train'``: teacher-forced decoding and a masked, batch- and
          time-averaged ``sequence_loss`` in ``self.loss``.
        * ``'decode'``: greedy or beam-search decoding; predicted ids in
          ``self.pred_id``.

        Args:
            keep_prob: dropout keep probability, forwarded to
                ``self.build_decoder_cell``.

        NOTE(review): reads module-level ``config`` and several instance
        attributes (mode, batch, beam_search, beam_with, the
        decoder_*_train tensors) set elsewhere — confirm they exist.
        """
        print("build decoder")
        with tf.variable_scope("decoder"):
            decoder_cell, decoder_init_state = self.build_decoder_cell(
                config.hidden_size, config.num_layers, keep_prob)

            initializer = tf.random_uniform_initializer(minval=-0.1,
                                                        maxval=0.1)
            self.decoder_embedding = tf.get_variable(
                name="decoder_embedding",
                shape=[config.decoder_vocab_size, config.embedding_size],
                initializer=initializer,
                dtype=tf.float32)
            # tf.summary.histogram("decoder_embed", self.decoder_embedding)

            output_layer = Dense(config.decoder_vocab_size)
            # input_layer = Dense(config.hidden_size, dtype=tf.float32)

            if self.mode == "train":
                # decoder_inputs_embedded: [n, max_time_step+1, embedding_size]
                decoder_inputs_embedded = tf.nn.embedding_lookup(
                    self.decoder_embedding, self.decoder_inputs_train)
                #self.decoder_inputs_embedded = input_layer(self.decoder_inputs_embedded)

                train_helper = TrainingHelper(
                    inputs=decoder_inputs_embedded,
                    sequence_length=self.decoder_length_train,
                    time_major=False,
                    name="traing_helper")
                train_decoder = BasicDecoder(cell=decoder_cell,
                                             helper=train_helper,
                                             initial_state=decoder_init_state,
                                             output_layer=output_layer)

                # Longest target sequence in the current batch.
                self.max_decoder_length = tf.reduce_max(
                    self.decoder_length_train)

                self.decoder_outputs_train, self.decoder_last_state_train, self.decoder_outputs_length_train = dynamic_decode(
                    decoder=train_decoder,
                    output_time_major=False,
                    impute_finished=True,
                    maximum_iterations=self.max_decoder_length)

                # output_layer already ran inside BasicDecoder, so
                # rnn_output is the logits; tf.identity just names it.
                self.decoder_logits_train = tf.identity(
                    self.decoder_outputs_train.rnn_output)
                self.pred = tf.argmax(self.decoder_logits_train, axis=-1)

                # mask for true and padded time steps. shape=[batch, max_length+1]
                self.masks = tf.sequence_mask(
                    lengths=self.decoder_length_train,
                    maxlen=self.max_decoder_length,
                    dtype=tf.float32,
                    name="masks")

                # decoder_logits_train: [batch_size, max_time_step + 1, num_decoder_symbols]
                # decoder_targets_train: [batch_size, max_time_steps + 1]
                self.loss = sequence_loss(logits=self.decoder_logits_train,
                                          targets=self.decoder_targets_train,
                                          weights=self.masks,
                                          average_across_timesteps=True,
                                          average_across_batch=True)

            elif self.mode == "decode":
                start_tokens = tf.ones([
                    self.batch,
                ], dtype=tf.int32) * config._GO
                end_token = config._EOS

                # Plain embedding lookup (no input projection at inference).
                def embed_input(inputs):
                    return tf.nn.embedding_lookup(self.decoder_embedding,
                                                  inputs)

                if self.beam_search:
                    # NOTE(review): 'beam_with' looks like a typo for
                    # 'beam_width' — confirm the attribute name on the class
                    # before renaming (kept as-is here).
                    pred_decoder = beam_search_decoder.BeamSearchDecoder(
                        cell=decoder_cell,
                        embedding=embed_input,
                        start_tokens=start_tokens,
                        end_token=end_token,
                        initial_state=decoder_init_state,
                        beam_width=self.beam_with,
                        output_layer=output_layer,
                    )

                else:
                    decoding_helper = GreedyEmbeddingHelper(
                        start_tokens=start_tokens,
                        end_token=end_token,
                        embedding=embed_input)
                    pred_decoder = BasicDecoder(
                        cell=decoder_cell,
                        helper=decoding_helper,
                        initial_state=decoder_init_state,
                        output_layer=output_layer)

                # NOTE(review): maximum_iterations=52 is a magic decode-length
                # cap — presumably tied to the dataset; consider a config field.
                self.decoder_outputs_decode, self.decoder_last_state_decode, self.decoder_outputs_length_decode = dynamic_decode(
                    decoder=pred_decoder,
                    output_time_major=False,
                    maximum_iterations=52)

                if self.beam_search:
                    self.pred_id = self.decoder_outputs_decode.predicted_ids
                else:
                    # expand_dims keeps the shape compatible with the
                    # beam-search output: [batch, time, 1].
                    self.pred_id = tf.expand_dims(
                        self.decoder_outputs_decode.sample_id, -1)

                self.shape = tf.shape(self.pred_id)
                if isinstance(self.decoder_last_state_decode, tf.Tensor):
                    self.ls_shape = tf.shape(self.decoder_last_state_decode)
                else:
                    print("not a tensor")
# Example #6
def seq_to_seq_net(embedding_dim, encoder_size, decoder_size, source_dict_dim,
                   target_dict_dim, is_generating, beam_size,
                   max_generation_length):
    """Build a seq2seq graph: bidirectional LSTM encoder + attention decoder.

    Args:
        embedding_dim: size of the source/target word embedding vectors.
        encoder_size: hidden size of each directional encoder LSTM.
        decoder_size: hidden size of the decoder LSTM.
        source_dict_dim: source vocabulary size.
        target_dict_dim: target vocabulary size.
        is_generating: False -> build the training graph and return the loss;
            True -> build the beam-search inference graph and return ids.
        beam_size: beam width used when is_generating is True.
        max_generation_length: decode-step cap used when is_generating is True.

    Returns:
        (placeholder_dict, loss) during training, or
        (placeholder_dict, predicted_ids) during generation.
    """
    src_word_idx = tf.placeholder(tf.int32, shape=[None, None])
    src_sequence_length = tf.placeholder(tf.int32, shape=[
        None,
    ])

    src_embedding_weights = tf.get_variable("source_word_embeddings",
                                            [source_dict_dim, embedding_dim])
    src_embedding = tf.nn.embedding_lookup(src_embedding_weights, src_word_idx)

    src_forward_cell = tf.nn.rnn_cell.BasicLSTMCell(encoder_size)
    src_reversed_cell = tf.nn.rnn_cell.BasicLSTMCell(encoder_size)
    # no peephole
    encoder_outputs, _ = tf.nn.bidirectional_dynamic_rnn(
        cell_fw=src_forward_cell,
        cell_bw=src_reversed_cell,
        inputs=src_embedding,
        sequence_length=src_sequence_length,
        dtype=tf.float32)

    # concat the forward outputs and backward outputs
    # encoded_vec: [batch, time, 2 * encoder_size]
    encoded_vec = tf.concat(encoder_outputs, axis=2)

    # project the encoder outputs to size of decoder lstm
    # BUGFIX: each directional encoder output has last dim `encoder_size`,
    # not `embedding_dim`, so the flattened feature size is 2 * encoder_size.
    # (The original reshape only worked when embedding_dim == encoder_size.)
    encoded_proj = tf.contrib.layers.fully_connected(inputs=tf.reshape(
        encoded_vec, shape=[-1, encoder_size * 2]),
                                                     num_outputs=decoder_size,
                                                     activation_fn=None,
                                                     biases_initializer=None)
    encoded_proj_reshape = tf.reshape(
        encoded_proj, shape=[-1, tf.shape(encoded_vec)[1], decoder_size])

    # get init state for decoder lstm's H from the first step of the
    # backward pass
    backward_first = tf.slice(encoder_outputs[1], [0, 0, 0], [-1, 1, -1])
    # BUGFIX: backward_first has feature size `encoder_size` (see above).
    decoder_boot = tf.contrib.layers.fully_connected(inputs=tf.reshape(
        backward_first, shape=[-1, encoder_size]),
                                                     num_outputs=decoder_size,
                                                     activation_fn=tf.nn.tanh,
                                                     biases_initializer=None)

    # prepare the initial state for decoder lstm: zero cell state, boot hidden
    cell_init = tf.zeros(tf.shape(decoder_boot), tf.float32)
    initial_state = LSTMStateTuple(cell_init, decoder_boot)

    # create decoder lstm cell; when generating, tile the attention
    # memories by beam_size so they line up with the beam-expanded batch
    decoder_cell = LSTMCellWithSimpleAttention(
        decoder_size,
        encoded_vec if not is_generating else seq2seq.tile_batch(
            encoded_vec, beam_size),
        encoded_proj_reshape if not is_generating else seq2seq.tile_batch(
            encoded_proj_reshape, beam_size),
        src_sequence_length if not is_generating else seq2seq.tile_batch(
            src_sequence_length, beam_size),
        forget_bias=0.0)

    output_layer = Dense(target_dict_dim, name='output_projection')

    if not is_generating:
        trg_word_idx = tf.placeholder(tf.int32, shape=[None, None])
        trg_sequence_length = tf.placeholder(tf.int32, shape=[
            None,
        ])
        trg_embedding_weights = tf.get_variable(
            "target_word_embeddings", [target_dict_dim, embedding_dim])
        trg_embedding = tf.nn.embedding_lookup(trg_embedding_weights,
                                               trg_word_idx)

        # Teacher forcing: feed the ground-truth target sequence.
        training_helper = seq2seq.TrainingHelper(
            inputs=trg_embedding,
            sequence_length=trg_sequence_length,
            time_major=False,
            name='training_helper')

        training_decoder = seq2seq.BasicDecoder(cell=decoder_cell,
                                                helper=training_helper,
                                                initial_state=initial_state,
                                                output_layer=output_layer)

        # get the max length of target sequence
        max_decoder_length = tf.reduce_max(trg_sequence_length)

        decoder_outputs_train, _, _ = seq2seq.dynamic_decode(
            decoder=training_decoder,
            output_time_major=False,
            impute_finished=True,
            maximum_iterations=max_decoder_length)

        decoder_logits_train = tf.identity(decoder_outputs_train.rnn_output)
        decoder_pred_train = tf.argmax(decoder_logits_train,
                                       axis=-1,
                                       name='decoder_pred_train')
        # mask padding steps out of the loss
        masks = tf.sequence_mask(lengths=trg_sequence_length,
                                 maxlen=max_decoder_length,
                                 dtype=tf.float32,
                                 name='masks')

        # place holder of label sequence
        lbl_word_idx = tf.placeholder(tf.int32, shape=[None, None])

        # compute the loss
        loss = seq2seq.sequence_loss(logits=decoder_logits_train,
                                     targets=lbl_word_idx,
                                     weights=masks,
                                     average_across_timesteps=True,
                                     average_across_batch=True)

        # return feeding list and loss operator
        return {
            'src_word_idx': src_word_idx,
            'src_sequence_length': src_sequence_length,
            'trg_word_idx': trg_word_idx,
            'trg_sequence_length': trg_sequence_length,
            'lbl_word_idx': lbl_word_idx
        }, loss
    else:
        start_tokens = tf.ones([
            tf.shape(src_word_idx)[0],
        ], tf.int32) * START_TOKEN_IDX
        # share the same embedding weights with target word
        trg_embedding_weights = tf.get_variable(
            "target_word_embeddings", [target_dict_dim, embedding_dim])

        inference_decoder = beam_search_decoder.BeamSearchDecoder(
            cell=decoder_cell,
            embedding=lambda tokens: tf.nn.embedding_lookup(
                trg_embedding_weights, tokens),
            start_tokens=start_tokens,
            end_token=END_TOKEN_IDX,
            initial_state=tf.nn.rnn_cell.LSTMStateTuple(
                tf.contrib.seq2seq.tile_batch(initial_state[0], beam_size),
                tf.contrib.seq2seq.tile_batch(initial_state[1], beam_size)),
            beam_width=beam_size,
            output_layer=output_layer)

        decoder_outputs_decode, _, _ = seq2seq.dynamic_decode(
            decoder=inference_decoder,
            output_time_major=False,
            # impute_finished=True raises inside BeamSearchDecoder,
            # so it is left at the default (False) here.
            maximum_iterations=max_generation_length)

        # predicted_ids: [batch_size, max_time_step, beam_width]
        predicted_ids = decoder_outputs_decode.predicted_ids

        return {
            'src_word_idx': src_word_idx,
            'src_sequence_length': src_sequence_length
        }, predicted_ids
    def _testDynamicDecodeRNN(self, time_major, has_attention):
        """Smoke-test dynamic_decode with a BeamSearchDecoder.

        Builds a small (optionally attention-wrapped) LSTM decoder, runs
        beam-search decoding, and checks the output types and static/dynamic
        shapes for both time-major and batch-major layouts.
        """
        encoder_sequence_length = [3, 2, 3, 1, 0]
        decoder_sequence_length = [2, 0, 1, 2, 3]
        batch_size = 5
        decoder_max_time = 4
        input_depth = 7
        cell_depth = 9
        attention_depth = 6
        vocab_size = 20
        end_token = vocab_size - 1
        start_token = 0
        embedding_dim = 50
        # Decoding is capped at the longest decoder sequence length.
        max_out = max(decoder_sequence_length)
        output_layer = layers_core.Dense(vocab_size,
                                         use_bias=True,
                                         activation=None)
        beam_width = 3

        with self.test_session() as sess:
            # Raw numpy embedding matrix; BeamSearchDecoder accepts either a
            # tensor/ndarray or a callable for `embedding`.
            embedding = np.random.randn(vocab_size,
                                        embedding_dim).astype(np.float32)
            cell = core_rnn_cell.LSTMCell(cell_depth)
            if has_attention:
                inputs = np.random.randn(batch_size, decoder_max_time,
                                         input_depth).astype(np.float32)
                attention_mechanism = attention_wrapper.BahdanauAttention(
                    num_units=attention_depth,
                    memory=inputs,
                    memory_sequence_length=encoder_sequence_length)
                cell = attention_wrapper.AttentionWrapper(
                    cell=cell,
                    attention_mechanism=attention_mechanism,
                    attention_size=attention_depth,
                    alignment_history=False)
            # Beam search expands the batch by beam_width, so the initial
            # state must be sized batch_size * beam_width.
            cell_state = cell.zero_state(dtype=dtypes.float32,
                                         batch_size=batch_size * beam_width)
            bsd = beam_search_decoder.BeamSearchDecoder(
                cell=cell,
                embedding=embedding,
                start_tokens=batch_size * [start_token],
                end_token=end_token,
                initial_state=cell_state,
                beam_width=beam_width,
                output_layer=output_layer,
                length_penalty_weight=0.0)

            final_outputs, final_state = decoder.dynamic_decode(
                bsd, output_time_major=time_major, maximum_iterations=max_out)

            def _t(shape):
                # Swap the leading batch/time dims when time-major.
                if time_major:
                    return (shape[1], shape[0]) + shape[2:]
                return shape

            self.assertTrue(
                isinstance(final_outputs,
                           beam_search_decoder.FinalBeamSearchDecoderOutput))
            self.assertTrue(
                isinstance(final_state,
                           beam_search_decoder.BeamSearchDecoderState))

            # Static shapes: the time dimension is unknown (None) before run.
            beam_search_decoder_output = final_outputs.beam_search_decoder_output
            self.assertEqual(
                _t((batch_size, None, beam_width)),
                tuple(beam_search_decoder_output.scores.get_shape().as_list()))
            self.assertEqual(
                _t((batch_size, None, beam_width)),
                tuple(final_outputs.predicted_ids.get_shape().as_list()))

            sess.run(variables.global_variables_initializer())
            sess_results = sess.run({
                'final_outputs': final_outputs,
                'final_state': final_state
            })

            # Mostly a smoke test
            time_steps = max_out
            self.assertEqual(
                _t((batch_size, time_steps, beam_width)),
                sess_results['final_outputs'].beam_search_decoder_output.
                scores.shape)
            self.assertEqual(
                _t((batch_size, time_steps, beam_width)),
                sess_results['final_outputs'].beam_search_decoder_output.
                predicted_ids.shape)
# 예제 #8
# 0
    def _build_decoder(self, encoder_outputs, encoder_state):
        """Build the attention decoder on top of the encoder outputs.

        Args:
            encoder_outputs: encoder memory used by the attention mechanism.
            encoder_state: final encoder state (tiled for beam search).

        Returns:
            (output_ids, outputs) — which tensor carries ids vs. logits or
            scores depends on mode and beam width; see the branches below.
        """
        with tf.name_scope("seq_decoder"):
            batch_size = self.batch_size
            # Use target lengths while training, source lengths otherwise.
            if self.mode == tf.contrib.learn.ModeKeys.TRAIN:
                sequence_length = self.iterator.target_length
            else:
                sequence_length = self.iterator.source_length
            # For beam-search inference, tile memory/state/lengths so the
            # effective batch becomes batch_size * beam_width.
            if (self.mode !=
                    tf.contrib.learn.ModeKeys.TRAIN) and self.beam_width > 1:
                batch_size = batch_size * self.beam_width
                encoder_outputs = beam_search_decoder.tile_batch(
                    encoder_outputs, multiplier=self.beam_width)
                encoder_state = nest.map_structure(
                    lambda s: beam_search_decoder.tile_batch(
                        s, self.beam_width), encoder_state)
                sequence_length = beam_search_decoder.tile_batch(
                    sequence_length, multiplier=self.beam_width)

            # BUGFIX: build a distinct cell object per layer. The original
            # created one `single_cell` and reused that same instance for
            # every layer of the MultiRNNCell, so all layers shared a single
            # cell (and its variables/state).
            decoder_cell = MultiRNNCell([
                single_rnn_cell(self.hparams.unit_type, self.num_units,
                                self.dropout)
                for _ in range(self.num_layers_decoder)
            ])
            decoder_cell = InputProjectionWrapper(decoder_cell,
                                                  num_proj=self.num_units)
            attention_mechanism = create_attention_mechanism(
                self.hparams.attention_mechanism,
                self.num_units,
                memory=encoder_outputs,
                source_sequence_length=sequence_length)
            decoder_cell = wrapper.AttentionWrapper(
                decoder_cell,
                attention_mechanism,
                attention_layer_size=self.num_units,
                output_attention=True,
                alignment_history=False)

            # Set the AttentionWrapperState's cell_state from the encoder.
            initial_state = decoder_cell.zero_state(batch_size=batch_size,
                                                    dtype=tf.float32)
            embeddings_decoder = tf.get_variable(
                "embedding_decoder",
                [self.num_decoder_symbols, self.num_units],
                initializer=self.initializer,
                dtype=tf.float32)
            output_layer = Dense(units=self.num_decoder_symbols,
                                 use_bias=True,
                                 name="output_layer")

            if self.mode == tf.contrib.learn.ModeKeys.TRAIN:
                decoder_inputs = tf.nn.embedding_lookup(
                    embeddings_decoder, self.iterator.target_in)
                decoder_helper = helper.TrainingHelper(
                    decoder_inputs, sequence_length=sequence_length)

                dec = basic_decoder.BasicDecoder(decoder_cell,
                                                 decoder_helper,
                                                 initial_state,
                                                 output_layer=output_layer)
                final_outputs, final_state, _ = decoder.dynamic_decode(dec)
                # NOTE(review): here `output_ids` carries logits and
                # `outputs` carries sample ids (the opposite of the naming) —
                # preserved as-is since callers may depend on it; confirm.
                output_ids = final_outputs.rnn_output
                outputs = final_outputs.sample_id
            else:

                def embedding_fn(inputs):
                    return tf.nn.embedding_lookup(embeddings_decoder, inputs)

                # Cap decoding at 2x the longest source sequence.
                decoding_length_factor = 2.0
                max_encoder_length = tf.reduce_max(self.iterator.source_length)
                maximum_iterations = tf.to_int32(
                    tf.round(
                        tf.to_float(max_encoder_length) *
                        decoding_length_factor))

                tgt_sos_id = tf.cast(
                    self.tgt_vocab_table.lookup(tf.constant(self.hparams.sos)),
                    tf.int32)
                tgt_eos_id = tf.cast(
                    self.tgt_vocab_table.lookup(tf.constant(self.hparams.eos)),
                    tf.int32)
                start_tokens = tf.fill([self.batch_size], tgt_sos_id)
                end_token = tgt_eos_id

                if self.beam_width == 1:
                    # Greedy decoding: argmax at each step.
                    decoder_helper = helper.GreedyEmbeddingHelper(
                        embedding=embedding_fn,
                        start_tokens=start_tokens,
                        end_token=end_token)
                    dec = basic_decoder.BasicDecoder(decoder_cell,
                                                     decoder_helper,
                                                     initial_state,
                                                     output_layer=output_layer)
                else:
                    dec = beam_search_decoder.BeamSearchDecoder(
                        cell=decoder_cell,
                        embedding=embedding_fn,
                        start_tokens=start_tokens,
                        end_token=end_token,
                        initial_state=initial_state,
                        output_layer=output_layer,
                        beam_width=self.beam_width)
                final_outputs, final_state, _ = decoder.dynamic_decode(
                    dec,
                    # swap_memory=True,
                    maximum_iterations=maximum_iterations)
                # This is the non-TRAIN branch, so the original
                # `mode == TRAIN or beam_width == 1` check reduced to the
                # beam-width test alone.
                if self.beam_width == 1:
                    output_ids = final_outputs.sample_id
                    outputs = final_outputs.rnn_output
                else:
                    output_ids = final_outputs.predicted_ids
                    outputs = final_outputs.beam_search_decoder_output.scores

            return output_ids, outputs
# 예제 #9
# 0
def build_network(
    hparams,
    char2numY,
    inputs,
    dec_inputs,
    keep_prob_=0.5,
):
    """Build a CNN feature extractor + LSTM seq2seq labeling network.

    Args:
        hparams: hyper-parameter object; reads akara2017, input_depth,
            n_channels, max_time_step, embed_size, bidirectional, num_units,
            lstm_layers, use_attention, attention_size, batch_size,
            use_beamsearch_decode, beam_width, output_max_length.
        char2numY: dict mapping output symbols to ids; must contain the
            special "<SOD>" and "<EOD>" tokens.
        inputs: raw input signal tensor.
        dec_inputs: decoder input id tensor (teacher forcing).
        keep_prob_: dropout keep probability for the akara2017 front end.

    Returns:
        (logits, pred_outputs, _final_state): teacher-forced training logits,
        inference predictions, and the training decoder's final state.
    """

    if hparams.akara2017 is True:
        _inputs = tf.reshape(inputs, [-1, hparams.input_depth, 1])
        network = build_firstPart_model(_inputs, keep_prob_)
        shape = network.get_shape().as_list()
        data_input_embed = tf.reshape(network,
                                      (-1, hparams.max_time_step, shape[1]))
    else:
        # BUGFIX: use floor division — under Python 3, `/` yields a float,
        # which is not a valid reshape dimension.
        _inputs = tf.reshape(inputs, [
            -1, hparams.n_channels,
            hparams.input_depth // hparams.n_channels
        ])

        # Three conv1d + max-pool stages: 32 -> 64 -> 128 filters.
        conv1 = tf.layers.conv1d(
            inputs=_inputs,
            filters=32,
            kernel_size=2,
            strides=1,
            padding="same",
            activation=tf.nn.relu,
        )
        max_pool_1 = tf.layers.max_pooling1d(inputs=conv1,
                                             pool_size=2,
                                             strides=2,
                                             padding="same")

        conv2 = tf.layers.conv1d(
            inputs=max_pool_1,
            filters=64,
            kernel_size=2,
            strides=1,
            padding="same",
            activation=tf.nn.relu,
        )
        max_pool_2 = tf.layers.max_pooling1d(inputs=conv2,
                                             pool_size=2,
                                             strides=2,
                                             padding="same")

        conv3 = tf.layers.conv1d(
            inputs=max_pool_2,
            filters=128,
            kernel_size=2,
            strides=1,
            padding="same",
            activation=tf.nn.relu,
        )
        max_pool_3 = tf.layers.max_pooling1d(inputs=conv3,
                                             pool_size=2,
                                             strides=2,
                                             padding="same")

        # Flatten the conv features into one vector per time step.
        shape = max_pool_3.get_shape().as_list()
        data_input_embed = tf.reshape(
            max_pool_3, (-1, hparams.max_time_step, shape[1] * shape[2]))

    # Embedding layer for the decoder inputs (teacher forcing).
    # NOTE(review): scope name "embeddin" (sic) is kept verbatim — renaming
    # it would change checkpoint variable names.
    with tf.variable_scope("embeddin"):
        decoder_embedding = tf.Variable(
            tf.random_uniform((len(char2numY), hparams.embed_size), -1.0, 1.0),
            name="dec_embedding",
        )  # +1 to consider <EOD>
        decoder_emb_inputs = tf.nn.embedding_lookup(decoder_embedding,
                                                    dec_inputs)

    with tf.variable_scope("encoding"):
        if not hparams.bidirectional:
            # Unidirectional stacked LSTM encoder; each layer gets its own
            # cell instance.
            def lstm_cell():
                return tf.contrib.rnn.LSTMCell(hparams.num_units)

            encoder_cell = tf.contrib.rnn.MultiRNNCell(
                [lstm_cell() for _ in range(hparams.lstm_layers)])
            encoder_outputs, encoder_state = tf.nn.dynamic_rnn(
                encoder_cell, inputs=data_input_embed, dtype=tf.float32)

        else:
            # Bidirectional stacked LSTM encoder.
            def lstm_cell():
                return tf.contrib.rnn.LSTMCell(hparams.num_units)

            stacked_cell_fw = tf.contrib.rnn.MultiRNNCell(
                [lstm_cell() for _ in range(hparams.lstm_layers)],
                state_is_tuple=True)
            stacked_cell_bw = tf.contrib.rnn.MultiRNNCell(
                [lstm_cell() for _ in range(hparams.lstm_layers)],
                state_is_tuple=True)

            (
                (enc_fw_out, enc_bw_out),
                (enc_fw_final, enc_bw_final),
            ) = tf.nn.bidirectional_dynamic_rnn(
                cell_fw=stacked_cell_fw,
                cell_bw=stacked_cell_bw,
                inputs=data_input_embed,
                dtype=tf.float32,
            )
            # Concatenate forward/backward states per layer so the decoder
            # (built with 2 * num_units when bidirectional) can consume them.
            encoder_final_state = []
            for layer in range(hparams.lstm_layers):
                enc_fin_c = tf.concat(
                    (enc_fw_final[layer].c, enc_bw_final[layer].c), 1)
                enc_fin_h = tf.concat(
                    (enc_fw_final[layer].h, enc_bw_final[layer].h), 1)
                encoder_final_state.append(
                    tf.contrib.rnn.LSTMStateTuple(c=enc_fin_c, h=enc_fin_h))

            encoder_state = tuple(encoder_final_state)
            encoder_outputs = tf.concat((enc_fw_out, enc_bw_out), 2)

    with tf.variable_scope("decoding"):

        output_layer = Dense(len(char2numY), use_bias=False)
        # Teacher-forced decode always runs max_time_step + 1 steps.
        decoder_lengths = np.ones(
            (hparams.batch_size), dtype=np.int32) * (hparams.max_time_step + 1)
        training_helper = tf.contrib.seq2seq.TrainingHelper(
            decoder_emb_inputs, decoder_lengths)

        if not hparams.bidirectional:

            def lstm_cell():
                return tf.contrib.rnn.LSTMCell(hparams.num_units)
        else:
            # Bidirectional encoder states are 2 * num_units wide.
            def lstm_cell():
                return tf.contrib.rnn.LSTMCell(2 * hparams.num_units)

        decoder_cells = tf.contrib.rnn.MultiRNNCell(
            [lstm_cell() for _ in range(hparams.lstm_layers)])

        if hparams.use_attention:
            # Create an attention mechanism over the encoder outputs.
            attention_mechanism = tf.contrib.seq2seq.LuongAttention(
                hparams.num_units *
                2 if hparams.bidirectional else hparams.num_units,
                encoder_outputs,
                memory_sequence_length=None,
            )

            decoder_cells = tf.contrib.seq2seq.AttentionWrapper(
                decoder_cells,
                attention_mechanism,
                attention_layer_size=hparams.attention_size,
                alignment_history=True,
            )

            # Wrap the encoder state into an AttentionWrapperState so it can
            # serve as the decoder's initial state.
            encoder_state = decoder_cells.zero_state(
                hparams.batch_size, tf.float32).clone(cell_state=encoder_state)

        # Teacher-forced training decoder.
        decoder = tf.contrib.seq2seq.BasicDecoder(decoder_cells,
                                                  training_helper,
                                                  encoder_state,
                                                  output_layer=output_layer)

        (
            dec_outputs,
            _final_state,
            _final_sequence_lengths,
        ) = tf.contrib.seq2seq.dynamic_decode(decoder, impute_finished=True)

    logits = dec_outputs.rnn_output

    # Inference
    start_tokens = tf.fill([hparams.batch_size], char2numY["<SOD>"])
    end_token = char2numY["<EOD>"]
    if not hparams.use_beamsearch_decode:

        inference_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
            decoder_embedding, start_tokens, end_token)

        # Inference Decoder
        inference_decoder = tf.contrib.seq2seq.BasicDecoder(
            decoder_cells,
            inference_helper,
            encoder_state,
            output_layer=output_layer)
    else:

        # Tile the encoder state by beam_width so it matches the
        # beam-expanded batch.
        encoder_state = tf.contrib.seq2seq.tile_batch(
            encoder_state, multiplier=hparams.beam_width)
        decoder_initial_state = decoder_cells.zero_state(
            hparams.batch_size * hparams.beam_width,
            tf.float32).clone(cell_state=encoder_state)

        inference_decoder = beam_search_decoder.BeamSearchDecoder(
            cell=decoder_cells,
            embedding=decoder_embedding,
            start_tokens=start_tokens,
            end_token=end_token,
            initial_state=decoder_initial_state,
            beam_width=hparams.beam_width,
            output_layer=output_layer,
        )

    # Dynamic decoding
    outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(
        inference_decoder,
        impute_finished=False,
        maximum_iterations=hparams.output_max_length,
    )
    pred_outputs = outputs.sample_id
    if hparams.use_beamsearch_decode:
        # [batch_size, max_time_step, beam_width]
        pred_outputs = pred_outputs[0]

    return logits, pred_outputs, _final_state
    def build_decoder(self):
        print("building decoder and attention..")
        with tf.variable_scope('decoder'):
            self.decoder_cell, self.decoder_initial_state = self.build_decoder_cell()

            initializer = tf.contrib.layers.xavier_initializer(seed=0, dtype=self.dtype)
            
            self.decoder_embeddings = tf.get_variable(name='embedding',
                shape=[self.num_decoder_symbols, self.decoder_embedding_size],
                initializer=initializer, dtype=self.dtype)

            input_layer = Dense(self.decoder_hidden_units, dtype=self.dtype, name='input_projection')
            output_layer = Dense(self.num_decoder_symbols, name='output_projection')

            if self.mode == 'train':
                self.decoder_inputs_embedded = tf.nn.embedding_lookup(
                    params=self.decoder_embeddings, ids=self.decoder_inputs_train)
               
                self.decoder_inputs_embedded = input_layer(self.decoder_inputs_embedded)

                training_helper = seq2seq.TrainingHelper(inputs=self.decoder_inputs_embedded,
                                                   sequence_length=self.decoder_inputs_length_train,
                                                   time_major=False,
                                                   name='training_helper')

                training_decoder = seq2seq.BasicDecoder(cell=self.decoder_cell,
                                                   helper=training_helper,
                                                   initial_state=self.decoder_initial_state,
                                                   output_layer=output_layer)
                                                   #output_layer=None)
                    
                max_decoder_length = tf.reduce_max(self.decoder_inputs_length_train)

                (self.decoder_outputs_train, self.decoder_last_state_train, 
                 self.decoder_outputs_length_train) = (seq2seq.dynamic_decode(
                    decoder=training_decoder,
                    output_time_major=False,
                    impute_finished=True,
                    maximum_iterations=max_decoder_length))
                 
                self.decoder_logits_train = tf.identity(self.decoder_outputs_train.rnn_output) 
                self.decoder_pred_train = tf.argmax(self.decoder_logits_train, axis=-1,
                                                    name='decoder_pred_train')

                masks = tf.sequence_mask(lengths=self.decoder_inputs_length_train, 
                                         maxlen=max_decoder_length, dtype=self.dtype, name='masks')

                self.loss = seq2seq.sequence_loss(logits=self.decoder_logits_train, 
                                                  targets=self.decoder_targets_train,
                                                  weights=masks,
                                                  average_across_timesteps=True,
                                                  average_across_batch=True,)

                tf.summary.scalar('loss', self.loss)

                # Contruct graphs for minimizing loss
                self.init_optimizer()

            elif self.mode == 'decode':
        
                # Start_tokens: [batch_size,] `int32` vector
                start_tokens = tf.ones([self.batch_size,], tf.int32) * data_utils.start_token
                end_token = data_utils.end_token

                def embed_and_input_proj(inputs):
                    return input_layer(tf.nn.embedding_lookup(self.decoder_embeddings, inputs))
                    
                if not self.use_beamsearch_decode:
                    # Helper to feed inputs for greedy decoding: uses the argmax of the output
                    decoding_helper = seq2seq.GreedyEmbeddingHelper(start_tokens=start_tokens,
                                                                    end_token=end_token,
                                                                    embedding=embed_and_input_proj)
                    # Basic decoder performs greedy decoding at each time step
                    print("building greedy decoder..")
                    inference_decoder = seq2seq.BasicDecoder(cell=self.decoder_cell,
                                                             helper=decoding_helper,
                                                             initial_state=self.decoder_initial_state,
                                                             output_layer=output_layer)
                else:
                    # Beamsearch is used to approximately find the most likely translation
                    print("building beamsearch decoder..")
                    inference_decoder = beam_search_decoder.BeamSearchDecoder(cell=self.decoder_cell,
                                                               embedding=embed_and_input_proj,
                                                               start_tokens=start_tokens,
                                                               end_token=end_token,
                                                               initial_state=self.decoder_initial_state,
                                                               beam_width=self.beam_width,
                                                               output_layer=output_layer,)
                # For GreedyDecoder, return
                # decoder_outputs_decode: BasicDecoderOutput instance
                #                         namedtuple(rnn_outputs, sample_id)
                # decoder_outputs_decode.rnn_output: [batch_size, max_time_step, num_decoder_symbols] 	if output_time_major=False
                #                                    [max_time_step, batch_size, num_decoder_symbols] 	if output_time_major=True
                # decoder_outputs_decode.sample_id: [batch_size, max_time_step], tf.int32		if output_time_major=False
                #                                   [max_time_step, batch_size], tf.int32               if output_time_major=True 
                
                # For BeamSearchDecoder, return
                # decoder_outputs_decode: FinalBeamSearchDecoderOutput instance
                #                         namedtuple(predicted_ids, beam_search_decoder_output)
                # decoder_outputs_decode.predicted_ids: [batch_size, max_time_step, beam_width] if output_time_major=False
                #                                       [max_time_step, batch_size, beam_width] if output_time_major=True
                # decoder_outputs_decode.beam_search_decoder_output: BeamSearchDecoderOutput instance
                #                                                    namedtuple(scores, predicted_ids, parent_ids)

                (self.decoder_outputs_decode, self.decoder_last_state_decode,
                 self.decoder_outputs_length_decode) = (seq2seq.dynamic_decode(
                    decoder=inference_decoder,
                    output_time_major=False,
                    #impute_finished=True,	# error occurs
                    maximum_iterations=self.max_decode_step))

                ### get alignment from decoder_last_state
                if self.use_attention:
                    self.alignment = self.decoder_last_state_decode[0].alignment_history.stack()
                else:
                    self.alignment = []

                if not self.use_beamsearch_decode:
                    # decoder_outputs_decode.sample_id: [batch_size, max_time_step]
                    # Or use argmax to find decoder symbols to emit:
                    # self.decoder_pred_decode = tf.argmax(self.decoder_outputs_decode.rnn_output,
                    #                                      axis=-1, name='decoder_pred_decode')

                    # Here, we use expand_dims to be compatible with the result of the beamsearch decoder
                    # decoder_pred_decode: [batch_size, max_time_step, 1] (output_major=False)
                    self.decoder_pred_decode = tf.expand_dims(self.decoder_outputs_decode.sample_id, -1)

                else:
                    # Use beam search to approximately find the most likely translation
                    # decoder_pred_decode: [batch_size, max_time_step, beam_width] (output_major=False)
                    self.decoder_pred_decode = self.decoder_outputs_decode.predicted_ids