Example #1
0
    def encode_decode(self,
                      encoder_channel_inputs,
                      encoder_attn_masks,
                      decoder_inputs,
                      targets,
                      target_weights,
                      encoder_copy_inputs=None):
        """Build the encoder-decoder graph and, when training, its loss.

        :param encoder_channel_inputs: Inputs passed to
            ``self.encoder.define_graph``.
        :param encoder_attn_masks: Encoder masks forwarded to the decoder.
        :param decoder_inputs: Decoder input steps.
        :param targets: Token targets used by the sequence loss.
        :param target_weights: Per-step weights used by the sequence loss.
        :param encoder_copy_inputs: Optional copy-source inputs forwarded to
            the decoder.
        :return: dict with keys 'output_symbols', 'sequence_logits',
            'losses', 'attn_alignments', 'encoder_hidden_states' and
            'decoder_hidden_states'; plus 'char_output_symbols' and
            'char_sequence_logits' when ``self.tg_char`` is set, and
            'pointers' when ``self.use_copy`` is set. In forward-only mode
            'losses' is ``tf.zeros_like(decoder_inputs[0])``.
        :raises NotImplementedError: if training with 'beam_search_opt'.
        :raises AttributeError: if the training algorithm is unrecognized.
        """
        bs_decoding = self.token_decoding_algorithm == 'beam_search' \
            and self.forward_only

        # --- Encode Step --- #
        if bs_decoding:
            # Adapt per-example inputs to the beam decoder (presumably tiling
            # them across beams -- see graph_utils.wrap_inputs).
            targets = graph_utils.wrap_inputs(self.decoder.beam_decoder,
                                              targets)
            encoder_copy_inputs = graph_utils.wrap_inputs(
                self.decoder.beam_decoder, encoder_copy_inputs)
        encoder_outputs, encoder_states = \
            self.encoder.define_graph(encoder_channel_inputs)

        # --- Decode Step --- #
        if self.tg_token_use_attention:
            # Stack the per-step encoder outputs into [-1, steps, output_dim].
            attention_states = tf.concat([
                tf.reshape(m, [-1, 1, self.encoder.output_dim])
                for m in encoder_outputs
            ], axis=1)
        else:
            attention_states = None
        num_heads = 2 if (self.tg_token_use_attention and self.copynet) else 1

        output_symbols, sequence_logits, output_logits, states, \
            attn_alignments, pointers = self.decoder.define_graph(
                encoder_states[-1], decoder_inputs,
                encoder_attn_masks=encoder_attn_masks,
                attention_states=attention_states,
                num_heads=num_heads,
                encoder_copy_inputs=encoder_copy_inputs)

        # --- Compute Losses --- #
        encoder_decoder_char_loss = 0
        if not self.forward_only:
            # A. Sequence Loss
            if self.training_algorithm == "standard":
                encoder_decoder_token_loss = self.sequence_loss(
                    output_logits, targets, target_weights,
                    graph_utils.sparse_cross_entropy)
            elif self.training_algorithm == 'beam_search_opt':
                # This used to be a silent `pass` that crashed further down
                # with a NameError on the undefined token loss.
                raise NotImplementedError(
                    "Training algorithm 'beam_search_opt' is not implemented.")
            else:
                raise AttributeError("Unrecognized training algorithm.")

            # B. Attention Regularization
            attention_reg = self.attention_regularization(attn_alignments) \
                if self.tg_token_use_attention else 0

        # C. Character decoding. Built in every mode so that the char outputs
        # returned below always exist (previously forward-only runs with
        # tg_char raised NameError); the char loss is added only in training.
        if self.tg_char:

            def split_steps(tensors, num_steps):
                # Concatenate along axis 0, then cut into `num_steps` tensors
                # along axis 1 and squeeze that axis away.
                return [
                    tf.squeeze(x, 1) for x in tf.split(
                        axis=1,
                        num_or_size_splits=num_steps,
                        value=tf.concat(axis=0, values=tensors))
                ]

            # re-arrange character inputs
            char_decoder_inputs = split_steps(self.char_decoder_inputs,
                                              self.max_target_token_size + 2)
            char_targets = split_steps(self.char_targets,
                                       self.max_target_token_size + 1)
            char_target_weights = split_steps(self.char_target_weights,
                                              self.max_target_token_size + 1)
            if bs_decoding:
                char_decoder_inputs = graph_utils.wrap_inputs(
                    self.decoder.beam_decoder, char_decoder_inputs)
                char_targets = graph_utils.wrap_inputs(
                    self.decoder.beam_decoder, char_targets)
                char_target_weights = graph_utils.wrap_inputs(
                    self.decoder.beam_decoder, char_target_weights)
            # get initial state from decoder output
            char_decoder_init_state = tf.concat(
                axis=0,
                values=[tf.reshape(d_o, [-1, self.decoder.dim])
                        for d_o in states])
            char_output_symbols, char_sequence_logits, char_output_logits, \
                _, _ = self.char_decoder.define_graph(
                    char_decoder_init_state, char_decoder_inputs)
            if not self.forward_only:
                # Integer division: the sampled-softmax sample count must be
                # an int, and `/` yields a float under Python 3.
                encoder_decoder_char_loss = self.sequence_loss(
                    char_output_logits, char_targets, char_target_weights,
                    graph_utils.softmax_loss(self.char_decoder.output_project,
                                             self.tg_char_vocab_size // 2,
                                             self.tg_char_vocab_size))

        if not self.forward_only:
            losses = encoder_decoder_token_loss + \
                     self.gamma * encoder_decoder_char_loss + \
                     self.beta * attention_reg
        else:
            losses = tf.zeros_like(decoder_inputs[0])

        # --- Store encoder/decoder output states --- #
        encoder_hidden_states = tf.concat(
            axis=1,
            values=[
                tf.reshape(e_o, [-1, 1, self.encoder.output_dim])
                for e_o in encoder_outputs
            ])

        # Keep only the top-layer hidden output of each decoder step.
        top_states = []
        if self.rnn_cell == 'gru':
            for state in states:
                top_states.append(state[:, -self.decoder.dim:])
        elif self.rnn_cell == 'lstm':
            for state in states:
                if self.num_layers > 1:
                    top_states.append(state[-1][1])
                else:
                    top_states.append(state[1])
        decoder_hidden_states = tf.concat(
            axis=1,
            values=[
                tf.reshape(d_o, [-1, 1, self.decoder.dim])
                for d_o in top_states
            ])

        O = {}
        O['output_symbols'] = output_symbols
        O['sequence_logits'] = sequence_logits
        O['losses'] = losses
        O['attn_alignments'] = attn_alignments
        O['encoder_hidden_states'] = encoder_hidden_states
        O['decoder_hidden_states'] = decoder_hidden_states
        if self.tg_char:
            O['char_output_symbols'] = char_output_symbols
            O['char_sequence_logits'] = char_sequence_logits
        if self.use_copy:
            O['pointers'] = pointers
        return O
Example #2
0
    def encode_decode(self, encoder_channel_inputs, encoder_attn_masks,
                      decoder_inputs, targets, target_weights):
        """Build the encoder-decoder graph and its combined loss.

        Side effects: sets ``self.encoder_hidden_states`` and
        ``self.decoder_hidden_states``.

        :param encoder_channel_inputs: Inputs passed to
            ``self.encoder.define_graph``.
        :param encoder_attn_masks: Encoder masks forwarded to the decoder.
        :param decoder_inputs: Decoder input steps.
        :param targets: Token targets used by the sequence loss.
        :param target_weights: Per-step weights used by the sequence loss.
        :return: list ``[output_symbols, output_logits, losses,
            attn_alignments]``, extended with the char output symbols and
            char output logits when ``self.tg_char`` is set and with the copy
            pointers when ``self.use_copy`` is set.
        :raises AttributeError: if training with an unrecognized algorithm.
        """

        encoder_outputs, encoder_states = \
            self.encoder.define_graph(encoder_channel_inputs)
        if self.tg_token_use_attention:
            # Stack the per-step encoder outputs into [-1, steps, output_dim].
            encoder_step_outputs = [
                tf.reshape(m, [-1, 1, self.encoder.output_dim])
                for m in encoder_outputs
            ]
            attention_states = tf.concat(axis=1, values=encoder_step_outputs)
        else:
            attention_states = None

        num_heads = 2 if (self.tg_token_use_attention
                          and self.use_copynet) else 1

        # --- Run encode-decode steps --- #
        output_symbols, output_logits, outputs, states, attn_alignments, \
            pointers = self.decoder.define_graph(
                        encoder_states[-1], decoder_inputs,
                        encoder_attn_masks=encoder_attn_masks,
                        attention_states=attention_states,
                        num_heads=num_heads,
                        encoder_copy_inputs=self.encoder_full_inputs)

        bs_decoding = self.forward_only and \
                      self.token_decoding_algorithm == 'beam_search'

        # --- Compute Losses --- #

        # A. Sequence Loss
        if self.forward_only or self.training_algorithm == "standard":
            if bs_decoding:
                targets = graph_utils.wrap_inputs(self.decoder.beam_decoder,
                                                  targets)
            if self.use_copynet:
                step_loss_fun = graph_utils.sparse_cross_entropy
            else:
                step_loss_fun = graph_utils.softmax_loss(
                    self.decoder.output_project, self.num_samples,
                    self.target_vocab_size)
            encoder_decoder_token_loss = self.sequence_loss(
                outputs, targets, target_weights, step_loss_fun)
        else:
            raise AttributeError("Unrecognized training algorithm.")

        # B. Attention Regularization
        attention_reg = self.attention_regularization(attn_alignments) \
            if self.tg_token_use_attention else 0

        # C. Supervised Copying Loss (if any)
        if self.use_copy and self.copy_fun == 'supervised':
            if bs_decoding:
                pointer_targets = self.decoder.beam_decoder.wrap_input(
                    self.pointer_targets)
            else:
                pointer_targets = self.pointer_targets
            copy_loss = self.copy_loss(pointers, pointer_targets)
        else:
            copy_loss = 0

        # D. Character Sequence Loss
        if self.tg_char:

            def split_steps(tensors, num_steps):
                # Concatenate along axis 0, then cut into `num_steps` tensors
                # along axis 1 and squeeze that axis away.
                return [
                    tf.squeeze(x, 1) for x in tf.split(
                        axis=1,
                        num_or_size_splits=num_steps,
                        value=tf.concat(axis=0, values=tensors))
                ]

            # re-arrange character inputs
            char_decoder_inputs = split_steps(self.char_decoder_inputs,
                                              self.max_target_token_size + 2)
            char_targets = split_steps(self.char_targets,
                                       self.max_target_token_size + 1)
            char_target_weights = split_steps(self.char_target_weights,
                                              self.max_target_token_size + 1)
            if bs_decoding:
                char_decoder_inputs = graph_utils.wrap_inputs(
                    self.decoder.beam_decoder, char_decoder_inputs)
                char_targets = graph_utils.wrap_inputs(
                    self.decoder.beam_decoder, char_targets)
                char_target_weights = graph_utils.wrap_inputs(
                    self.decoder.beam_decoder, char_target_weights)
            # get initial state from decoder output
            char_decoder_init_state = tf.concat(
                axis=0,
                values=[
                    tf.reshape(d_o, [-1, self.decoder.dim]) for d_o in outputs
                ])
            char_output_symbols, char_output_logits, char_outputs, _, _ = \
                self.char_decoder.define_graph(
                    char_decoder_init_state, char_decoder_inputs)
            # Integer division: the sampled-softmax sample count must be an
            # int, and `/` yields a float under Python 3.
            encoder_decoder_char_loss = self.sequence_loss(
                char_outputs, char_targets, char_target_weights,
                graph_utils.softmax_loss(self.char_decoder.output_project,
                                         self.tg_char_vocab_size // 2,
                                         self.tg_char_vocab_size))
        else:
            encoder_decoder_char_loss = 0

        losses = encoder_decoder_token_loss + \
                 self.gamma * encoder_decoder_char_loss + \
                 self.chi * copy_loss + \
                 self.beta * attention_reg

        # store encoder/decoder output states
        self.encoder_hidden_states = tf.concat(
            axis=1,
            values=[
                tf.reshape(e_o, [-1, 1, self.encoder.output_dim])
                for e_o in encoder_outputs
            ])

        # Keep only the top-layer hidden output of each decoder step.
        top_states = []
        if self.rnn_cell == 'gru':
            for state in states:
                top_states.append(state[:, -self.decoder.dim:])
        elif self.rnn_cell == 'lstm':
            for state in states:
                if self.num_layers > 1:
                    top_states.append(state[-1][1])
                else:
                    top_states.append(state[1])
        self.decoder_hidden_states = tf.concat(
            axis=1,
            values=[
                tf.reshape(d_o, [-1, 1, self.decoder.dim])
                for d_o in top_states
            ])

        O = [output_symbols, output_logits, losses, attn_alignments]
        if self.tg_char:
            O.append(char_output_symbols)
            O.append(char_output_logits)
        if self.use_copy:
            O.append(pointers)
        return O
Example #3
0
    def define_graph(self,
                     encoder_state,
                     decoder_inputs,
                     input_embeddings=None,
                     encoder_attn_masks=None,
                     attention_states=None,
                     num_heads=1,
                     encoder_copy_inputs=None):
        """
        Build the decoder graph, optionally with attention, CopyNet and
        beam-search decoding.

        :param encoder_state: Encoder state => initial decoder state.
        :param decoder_inputs: Decoder training inputs ("<START>, ... <EOS>").
        :param input_embeddings: Decoder vocabulary embedding; defaults to
            ``self.embeddings()``.
        :param encoder_attn_masks: Binary masks whose entries corresponding to
            non-padding tokens are 1.
        :param attention_states: Encoder hidden states.
        :param num_heads: Number of attention heads.
        :param encoder_copy_inputs: Array of encoder copy inputs where the
            copied words are represented using target vocab indices and place
            holding indices are used elsewhere.
        :return: 6-tuple. Without beam search: (output_symbols,
            sequence_logits, per-step output logits, states, attn_alignments,
            pointers). With beam search: (top-k symbol sequences per example,
            top-k sequence scores per example, states, states,
            attn_alignments, pointers). ``attn_alignments`` is None when
            attention is disabled and ``pointers`` is None without CopyNet.
        :raises ValueError: if attention is used and dimension 1 of
            ``attention_states`` is not statically known, or if that
            dimension does not match ``len(encoder_copy_inputs)``.
        """
        # Explicit checks instead of `assert` so they survive `python -O`.
        if self.use_attention and \
                not attention_states.get_shape()[1:2].is_fully_defined():
            raise ValueError(
                "Dimension 1 of attention_states must be known: {}".format(
                    attention_states.get_shape()))
        if encoder_copy_inputs and \
                not attention_states.get_shape()[1] == len(encoder_copy_inputs):
            raise ValueError(
                "attention_states length does not match the {} encoder copy "
                "inputs: {}".format(len(encoder_copy_inputs),
                                    attention_states.get_shape()))
        bs_decoding = self.forward_only and \
                      self.decoding_algorithm == "beam_search"

        if input_embeddings is None:
            input_embeddings = self.embeddings()

        if self.force_reading_input:
            print(
                "Warning: reading ground truth decoder inputs at decoding time."
            )

        with tf.compat.v1.variable_scope(self.scope + "_decoder_rnn") as scope:
            decoder_cell = self.decoder_cell()
            states = []
            alignments_list = []
            pointers = None
            # Stays None when attention is disabled; previously it was left
            # unbound and both return statements raised UnboundLocalError.
            attn_alignments = None
            # Built lazily (once) from encoder_copy_inputs for CopyNet.
            encoder_copy_inputs_2d = None

            # Cell Wrappers -- 'Attention', 'CopyNet', 'BeamSearch'
            if bs_decoding:
                beam_decoder = self.beam_decoder
                state = beam_decoder.wrap_state(encoder_state,
                                                self.output_project)
            else:
                state = encoder_state
                past_output_symbols = []
                past_output_logits = []

            if self.use_attention:
                if bs_decoding:
                    encoder_attn_masks = graph_utils.wrap_inputs(
                        beam_decoder, encoder_attn_masks)
                    attention_states = beam_decoder.wrap_input(
                        attention_states)
                encoder_attn_masks = tf.concat(
                    axis=1,
                    values=[tf.expand_dims(mask, 1)
                            for mask in encoder_attn_masks])
                decoder_cell = decoder.AttentionCellWrapper(
                    decoder_cell, attention_states, encoder_attn_masks,
                    self.attention_function, self.attention_input_keep,
                    self.attention_output_keep, num_heads, self.dim,
                    self.num_layers, self.use_copy, self.vocab_size)

            if self.use_copy and self.copy_fun == 'copynet':
                decoder_cell = decoder.CopyCellWrapper(decoder_cell,
                                                       self.output_project,
                                                       self.num_layers,
                                                       encoder_copy_inputs,
                                                       self.vocab_size)

            if bs_decoding:
                decoder_cell = beam_decoder.wrap_cell(decoder_cell,
                                                      self.output_project)

            def step_output_symbol_and_logit(output):
                # Log-probabilities of one step's output and its argmax
                # symbol; both are recorded for the greedy return value.
                epsilon = tf.constant(1e-12)
                if self.copynet:
                    output_logits = tf.math.log(output + epsilon)
                else:
                    W, b = self.output_project
                    output_logits = tf.math.log(
                        tf.nn.softmax(tf.matmul(output, W) + b) + epsilon)
                output_symbol = tf.argmax(input=output_logits, axis=1)
                past_output_symbols.append(output_symbol)
                past_output_logits.append(output_logits)
                return output_symbol, output_logits

            for i, step_input in enumerate(decoder_inputs):
                if bs_decoding:
                    step_input = beam_decoder.wrap_input(step_input)

                if i > 0:
                    scope.reuse_variables()
                    if self.forward_only:
                        if self.decoding_algorithm == "beam_search":
                            # state = (beam symbols [batch*beam, max_len],
                            # right-aligned; beam log-probs; cell states).
                            # Feed back the latest symbol of each beam.
                            past_beam_symbols, _, _ = state
                            step_input = past_beam_symbols[:, -1]
                        elif self.decoding_algorithm == "greedy":
                            output_symbol, _ = step_output_symbol_and_logit(
                                output)
                            if not self.force_reading_input:
                                step_input = tf.cast(output_symbol,
                                                     dtype=tf.int32)
                    else:
                        step_output_symbol_and_logit(output)
                    if self.copynet:
                        # Keep the raw (possibly copy-extended) index and map
                        # ids >= target_vocab_size to UNK for the lookup.
                        decoder_input = step_input
                        step_input = tf.compat.v1.where(
                            step_input >= self.target_vocab_size,
                            tf.ones_like(step_input) * data_utils.UNK_ID,
                            step_input)

                input_embedding = tf.nn.embedding_lookup(
                    params=input_embeddings, ids=step_input)

                # Appending selective read information for CopyNet
                if self.copynet:
                    attn_length = attention_states.get_shape()[1]
                    attn_dim = attention_states.get_shape()[2]
                    if i == 0:
                        # Append dummy zero vector to the <START> token
                        selective_reads = tf.zeros([self.batch_size, attn_dim])
                        if bs_decoding:
                            selective_reads = beam_decoder.wrap_input(
                                selective_reads)
                    else:
                        if encoder_copy_inputs_2d is None:
                            encoder_copy_inputs_2d = tf.concat(
                                [tf.expand_dims(x, 1)
                                 for x in encoder_copy_inputs],
                                axis=1)
                        if self.forward_only:
                            # Ids >= target_vocab_size refer to a source
                            # position: fetch the token at that position.
                            # This must use the raw decoder_input --
                            # step_input was already replaced by UNK there,
                            # so using it gave a negative one_hot index
                            # (an all-zero row) and always copied token 0.
                            copy_input = tf.compat.v1.where(
                                decoder_input >= self.target_vocab_size,
                                tf.reduce_sum(
                                    input_tensor=tf.one_hot(
                                        decoder_input - self.target_vocab_size,
                                        depth=attn_length,
                                        dtype=tf.int32) *
                                    encoder_copy_inputs_2d,
                                    axis=1),
                                decoder_input)
                        else:
                            copy_input = decoder_input
                        tiled_copy_input = tf.tile(
                            input=tf.reshape(copy_input, [-1, 1]),
                            multiples=np.array([1, attn_length]))
                        # [batch_size(*self.beam_size), max_source_length]
                        selective_mask = tf.cast(
                            tf.equal(tiled_copy_input, encoder_copy_inputs_2d),
                            dtype=tf.float32)
                        # [batch_size(*self.beam_size), max_source_length]
                        weighted_selective_mask = tf.nn.softmax(
                            selective_mask * alignments[1])
                        # [batch_size(*self.beam_size), max_source_length, attn_dim]
                        weighted_selective_mask_3d = tf.tile(
                            input=tf.expand_dims(weighted_selective_mask, 2),
                            multiples=np.array([1, 1, attn_dim]))
                        # [batch_size(*self.beam_size), attn_dim]
                        selective_reads = tf.reduce_sum(
                            input_tensor=weighted_selective_mask_3d *
                            attention_states,
                            axis=1)
                    input_embedding = tf.concat(
                        [input_embedding, selective_reads], axis=1)

                if self.copynet or self.use_attention:
                    output, state, alignments, _ = \
                        decoder_cell(input_embedding, state)
                    alignments_list.append(alignments)
                else:
                    output, state = decoder_cell(input_embedding, state)

                # save output states
                if not bs_decoding:
                    # when doing beam search decoding, the output state of each
                    # step cannot simply be gathered step-wise outside the decoder
                    # (special case: beam_size = 1)
                    states.append(state)

            if self.use_attention:
                # Tensor list --> tensor
                attn_alignments = tf.concat(
                    axis=1,
                    values=[tf.expand_dims(x[0], 1) for x in alignments_list])
            if self.copynet:
                pointers = tf.concat(
                    axis=1,
                    values=[tf.expand_dims(x[1], 1) for x in alignments_list])

            if bs_decoding:

                def drop_first_step(tensor):
                    # Split along axis 1 into single-step tensors, skip
                    # index 0 (presumably the initial state), squeeze axis 1.
                    return [
                        tf.squeeze(x, axis=[1]) for x in tf.split(
                            tensor, tensor.get_shape()[1], axis=1)[1:]
                    ]

                # Beam-search output
                (
                    past_beam_symbols,  # [batch_size*self.beam_size, max_len], right-aligned!!!
                    past_beam_logprobs,  # [batch_size*self.beam_size]
                    past_cell_states,
                ) = state
                # [self.batch_size, self.beam_size, max_len]
                top_k_osbs = tf.reshape(past_beam_symbols[:, 1:],
                                        [self.batch_size, self.beam_size, -1])
                top_k_osbs = tf.split(axis=0,
                                      num_or_size_splits=self.batch_size,
                                      value=top_k_osbs)
                top_k_osbs = [
                    tf.split(axis=0,
                             num_or_size_splits=self.beam_size,
                             value=tf.squeeze(top_k_output, axis=[0]))
                    for top_k_output in top_k_osbs
                ]
                top_k_osbs = [[
                    tf.squeeze(beam_output, axis=[0])
                    for beam_output in top_k_output
                ] for top_k_output in top_k_osbs]
                # [self.batch_size, self.beam_size]
                top_k_seq_logits = tf.reshape(
                    past_beam_logprobs, [self.batch_size, self.beam_size])
                top_k_seq_logits = tf.split(axis=0,
                                            num_or_size_splits=self.batch_size,
                                            value=top_k_seq_logits)
                top_k_seq_logits = [
                    tf.squeeze(top_k_logit, axis=[0])
                    for top_k_logit in top_k_seq_logits
                ]
                if self.use_attention:
                    attn_alignments = tf.reshape(attn_alignments, [
                        self.batch_size, self.beam_size,
                        len(decoder_inputs),
                        attention_states.get_shape()[1]
                    ])
                # LSTM: ([batch_size*self.beam_size, :, dim],
                #        [batch_size*self.beam_size, :, dim])
                # GRU: [batch_size*self.beam_size, :, dim]
                if self.rnn_cell == 'lstm':
                    if self.num_layers == 1:
                        c_states, h_states = past_cell_states
                        # Index 0 is dropped here too, matching the GRU and
                        # multi-layer branches.
                        states = list(zip(drop_first_step(c_states),
                                          drop_first_step(h_states)))
                    else:
                        # Per layer: a list of per-step (c, h) pairs ...
                        layered_states = [
                            list(zip(drop_first_step(c_states),
                                     drop_first_step(h_states)))
                            for c_states, h_states in past_cell_states
                        ]
                        # ... transposed to per step: a tuple of per-layer
                        # (c, h) pairs. (zip() without the star merely
                        # wrapped each layer's list in a 1-tuple.)
                        states = list(zip(*layered_states))
                elif self.rnn_cell in ['gru', 'ran']:
                    states = drop_first_step(past_cell_states)
                else:
                    raise AttributeError(
                        "Unrecognized rnn cell type: {}".format(self.rnn_cell))
                return top_k_osbs, top_k_seq_logits, states, \
                       states, attn_alignments, pointers
            else:
                # Greedy output: record the last step's output too.
                step_output_symbol_and_logit(output)
                output_symbols = tf.concat(
                    [tf.expand_dims(x, 1) for x in past_output_symbols],
                    axis=1)
                sequence_logits = tf.add_n([
                    tf.reduce_max(input_tensor=x, axis=1)
                    for x in past_output_logits
                ])
                return output_symbols, sequence_logits, past_output_logits, \
                       states, attn_alignments, pointers
Example #4
0
    def define_graph(self, encoder_state, decoder_inputs,
                     input_embeddings=None, encoder_attn_masks=None,
                     attention_states=None, num_heads=1,
                     encoder_copy_inputs=None):
        """
        :return output_symbols: batch of discrete output sequences
        :return output_logits: batch of output sequence scores
        :return outputs: batch output states for all steps
        :return states: batch hidden states for all steps
        :return attn_alignments: batch attention masks
                                 (if attention mechanism is used)
        """
        if input_embeddings is None:
            input_embeddings = self.embeddings()

        if self.use_attention and \
                not attention_states.get_shape()[1:2].is_fully_defined():
            raise ValueError("Shape [1] and [2] of attention_states must be "
                             "known %s" % attention_states.get_shape())

        bs_decoding = self.forward_only and \
                      self.decoding_algorithm == "beam_search"

        if self.force_reading_input:
            print("Warning: reading ground truth decoder inputs at decoding time.")

        with tf.variable_scope(self.scope + "_decoder_rnn") as scope:
            decoder_cell = self.decoder_cell()
            outputs = []
            states = []
            alignments_list = []

            # applying cell wrappers: ["attention", "beam"]
            if bs_decoding:
                beam_decoder = self.beam_decoder
                state = beam_decoder.wrap_state(
                    encoder_state, self.output_project)
                encoder_copy_inputs = graph_utils.wrap_inputs(
                    beam_decoder, encoder_copy_inputs)
            else:
                state = encoder_state
                past_output_symbols = []
                past_output_logits = tf.cast(decoder_inputs[0] * 0, tf.float32)

            if self.use_attention:
                if bs_decoding:
                    encoder_attn_masks = graph_utils.wrap_inputs(
                        beam_decoder, encoder_attn_masks)
                    attention_states = beam_decoder.wrap_input(attention_states)
                encoder_attn_masks = [tf.expand_dims(encoder_attn_mask, 1)
                    for encoder_attn_mask in encoder_attn_masks]
                encoder_attn_masks = tf.concat(axis=1, values=encoder_attn_masks)
                decoder_cell = decoder.AttentionCellWrapper(
                    decoder_cell, attention_states, encoder_attn_masks,
                    encoder_copy_inputs, self.attention_function,
                    self.attention_input_keep, self.attention_output_keep,
                    num_heads, self.dim, self.num_layers, self.use_copy,
                    self.vocab_size)

            if self.use_copy and self.copy_fun != 'supervised':
                decoder_cell = decoder.CopyCellWrapper(decoder_cell,
                    self.output_project, self.num_layers, encoder_copy_inputs,
                    self.vocab_size, self.generation_mask)

            if bs_decoding:
                decoder_cell = beam_decoder.wrap_cell(
                    decoder_cell, self.output_project)

            for i, input in enumerate(decoder_inputs):
                if bs_decoding:
                    input = beam_decoder.wrap_input(input)

                if i > 0:
                    scope.reuse_variables()
                    if self.forward_only and not self.force_reading_input:
                        if self.decoding_algorithm == "beam_search":
                            (
                                past_cand_symbols,  # [batch_size, max_len]
                                past_cand_logprobs, # [batch_size]
                                past_beam_symbols,  # [batch_size*self.beam_size, max_len], right-aligned!!!
                                past_beam_logprobs, # [batch_size*self.beam_size]
                                past_cell_states,   # [batch_size*self.beam_size, max_len, dim]
                            ) = state
                            input = past_beam_symbols[:, -1]
                        elif self.decoding_algorithm == "greedy":
                            if self.use_copy and self.copy_fun != 'supervised':
                                epsilon = tf.constant(1e-12)
                                projected_output = tf.log(output + epsilon)
                            else:
                                W, b = self.output_project
                                projected_output = \
                                    tf.nn.log_softmax(tf.matmul(output, W) + b)
                            output_symbol = tf.argmax(projected_output, 1)
                            past_output_symbols.append(tf.expand_dims(output_symbol, 1))
                            past_output_logits = \
                                tf.add(past_output_logits, tf.reduce_max(projected_output, 1))
                            input = tf.cast(output_symbol, dtype=tf.int32)
                        input = tf.where(input >= self.target_vocab_size,
                            tf.ones(tf.shape(input), dtype=tf.int32) * data_utils.UNK_ID, input)

                input_embedding = tf.nn.embedding_lookup(input_embeddings, input)

                if self.use_copynet:
                    if i == 0:
                        attn_dim = attention_states.get_shape()[2]
                        selective_reads = tf.zeros([self.batch_size, attn_dim])
                        if bs_decoding:
                            selective_reads = beam_decoder.wrap_input(selective_reads)
                    else:
                        selective_reads = attns[-1] * read_copy_source
                    input_embedding = tf.concat(axis=1, values=[input_embedding, selective_reads])
                    output, state, alignments, attns, read_copy_source = \
                        decoder_cell(input_embedding, state)
                    alignments_list.append(alignments)
                elif self.use_attention:
                    output, state, alignments, attns = \
                        decoder_cell(input_embedding, state)
                    alignments_list.append(alignments)
                else:
                    output, state = decoder_cell(input_embedding, state)
               
                # record output state to compute the loss.
                if bs_decoding:
                    # when doing beam search decoding, the output state of each
                    # step cannot simply be gathered step-wise outside the decoder
                    # (speical case: beam_size = 1)
                    pass
                else:
                    outputs.append(output)
                    states.append(state)

            if self.use_attention:
                # Tensor list --> tenosr
                attn_alignments = tf.concat(axis=1,
                    values=[tf.expand_dims(x[0], 1) for x in alignments_list])
            if self.use_copynet:
                pointers = tf.concat(axis=1,
                    values=[tf.expand_dims(x[1], 1) for x in alignments_list])
            else:
                pointers = None

            if bs_decoding:
                # Beam-search output
                (
                    past_cand_symbols,  # [batch_size, max_len]
                    past_cand_logprobs, # [batch_size]
                    past_beam_symbols,  # [batch_size*self.beam_size, max_len], right-aligned!!!
                    past_beam_logprobs, # [batch_size*self.beam_size]
                    past_cell_states,
                ) = state
                # [self.batch_size, self.beam_size, max_len]
                top_k_outputs = tf.reshape(past_beam_symbols[:, 1:],
                                           [self.batch_size, self.beam_size, -1])
                top_k_outputs = tf.split(axis=0, num_or_size_splits=self.batch_size, value=top_k_outputs)
                top_k_outputs = [tf.split(axis=0, num_or_size_splits=self.beam_size,
                                          value=tf.squeeze(top_k_output, axis=[0]))
                                 for top_k_output in top_k_outputs]
                top_k_outputs = [[tf.squeeze(output, axis=[0]) for output in top_k_output]
                                 for top_k_output in top_k_outputs]
                # [self.batch_size, self.beam_size]
                top_k_logits = tf.reshape(past_beam_logprobs, [self.batch_size, self.beam_size])
                top_k_logits = tf.split(axis=0, num_or_size_splits=self.batch_size, value=top_k_logits)
                top_k_logits = [tf.squeeze(top_k_logit, axis=[0])
                                for top_k_logit in top_k_logits]
                if self.use_attention:
                    attn_alignments = tf.reshape(attn_alignments,
                            [self.batch_size, self.beam_size, len(decoder_inputs),
                             attention_states.get_shape()[1].value])
                # LSTM: ([batch_size*self.beam_size, :, dim],
                #        [batch_size*self.beam_size, :, dim])
                # GRU: [batch_size*self.beam_size, :, dim]
                if self.rnn_cell == 'lstm':
                    if self.num_layers == 1:
                        c_states, h_states = past_cell_states
                        states = list(zip(
                            [tf.squeeze(x, axis=[1])
                             for x in tf.split(axis=1, num_or_size_splits=c_states.get_shape()[1], value=c_states)],
                            [tf.squeeze(x, axis=[1])
                             for x in tf.split(axis=1, num_or_size_splits=h_states.get_shape()[1], value=h_states)]))
                    else:
                        layered_states = [list(zip(
                                [tf.squeeze(x, axis=[1]) 
                                    for x in tf.split(axis=1, num_or_size_splits=c_states.get_shape()[1], value=c_states)[1:]],
                                [tf.squeeze(x, axis=[1])
                                    for x in tf.split(axis=1, num_or_size_splits=h_states.get_shape()[1], value=h_states)[1:]]))
                            for c_states, h_states in past_cell_states]
                        states = list(zip(layered_states))
                elif self.rnn_cell in ['gru', 'ran']:
                    states = [tf.squeeze(x, axis=[1]) for x in \
                        tf.split(num_or_size_splits=past_cell_states.get_shape()[1],
                                 axis=1, value=past_cell_states)][1:]
                else:
                    raise AttributeError(
                        "Unrecognized rnn cell type: {}".format(self.rnn_cell))

                # TODO: correct beam search output logits computation
                # so far dummy zero vectors are used
                if self.use_copy and self.copy_fun != 'supervised':
                    outputs = [tf.zeros([self.batch_size * self.beam_size, self.vocab_size])
                           for s in states]
                else:
                    outputs = [tf.zeros([self.batch_size * self.beam_size, self.dim])
                           for s in states]
                return top_k_outputs, top_k_logits, outputs, states, attn_alignments, pointers
            else:
                # Greedy output
                if self.use_copynet:
                    epsilon = tf.constant(1e-12)
                    projected_output = tf.log(output + epsilon)
                else:
                    W, b = self.output_project
                    projected_output = tf.nn.log_softmax(tf.matmul(output, W) + b)
                output_symbol = tf.argmax(projected_output, 1)
                past_output_symbols.append(tf.expand_dims(output_symbol, 1))
                output_symbols = tf.concat(axis=1, values=past_output_symbols) \
                    if self.forward_only else tf.cast(input, tf.float32)
                past_output_logits = tf.add(
                    past_output_logits, tf.reduce_max(projected_output, 1))
                return output_symbols, past_output_logits, outputs, states, \
                    attn_alignments, pointers