Code example #1
def encoder(src_embedding, src_sequence_length):
    # src_embedding: [batch_size, sequence_length, ...]
    # bidirectional GRU encoder

    # build the forward RNN with a GRUCell
    encoder_fwd_cell = layers.GRUCell(hidden_size=hidden_dim)
    encoder_fwd_output, fwd_state = layers.rnn(
        cell=encoder_fwd_cell,
        inputs=src_embedding,
        sequence_length=src_sequence_length,
        time_major=False,
        is_reverse=False)

    # build the backward RNN with a GRUCell
    encoder_bwd_cell = layers.GRUCell(hidden_size=hidden_dim)
    encoder_bwd_output, bwd_state = layers.rnn(
        cell=encoder_bwd_cell,
        inputs=src_embedding,
        sequence_length=src_sequence_length,
        time_major=False,
        is_reverse=True)

    # concatenate the forward and backward GRU results to get h
    encoder_output = layers.concat(
        input=[encoder_fwd_output, encoder_bwd_output], axis=2)
    encoder_state = layers.concat(input=[fwd_state, bwd_state], axis=1)
    return encoder_output, encoder_state
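For orientation, here is a minimal sketch of how this encoder might be fed. The sizes, `source_dict_size`, and the data layers are illustrative assumptions, not code from the original project:

import paddle.fluid as fluid
import paddle.fluid.layers as layers

hidden_dim = 512          # hypothetical value for hidden_size above
source_dict_size = 30000  # hypothetical vocabulary size

src = fluid.data(name="src", shape=[None, None], dtype="int64")
src_sequence_length = fluid.data(
    name="src_sequence_length", shape=[None], dtype="int64")
src_embedding = fluid.embedding(
    input=src,
    size=[source_dict_size, hidden_dim],
    param_attr=fluid.ParamAttr(name="src_emb_table"))

# encoder_output: [batch_size, seq_len, 2 * hidden_dim]
# encoder_state:  [batch_size, 2 * hidden_dim]
encoder_output, encoder_state = encoder(src_embedding, src_sequence_length)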
Code example #2
    def forward(self, inputs, input_lens):
        """run bi-rnn

        Args:
            inputs (TYPE): NULL
            input_lens (TYPE): NULL

        Returns: TODO

        Raises: NULL

        """
        fwd_output, fwd_final_state = layers.rnn(self.fwd_cell,
                                                 inputs,
                                                 sequence_length=input_lens)
        if self._bidirectional:
            bwd_output, bwd_final_state = layers.rnn(
                self.bwd_cell,
                inputs,
                sequence_length=input_lens,
                is_reverse=True)

            output = layers.concat(input=[fwd_output, bwd_output], axis=-1)
            final_state = [
                layers.concat(input=[fwd_final_state[0], bwd_final_state[0]],
                              axis=-1),
                layers.concat(input=[fwd_final_state[1], bwd_final_state[1]],
                              axis=-1),
            ]
        else:
            output = fwd_output
            final_state = fwd_final_state

        return output, final_state
Code example #3
    def __call__(self, src_emb, src_sequence_length):
        encoder_output, encoder_final_state = layers.rnn(
            cell=self.encoder_cell,
            inputs=src_emb,
            sequence_length=src_sequence_length,
            is_reverse=False)
        return encoder_output, encoder_final_state
Code example #4
    def _build_decoder(self, enc_final_state, mode='train', beam_size=10):

        dec_cell = DecoderCell(self.num_layers, self.hidden_size, self.dropout,
                               self.init_scale)
        output_layer = lambda x: layers.fc(x,
                                           size=self.tar_vocab_size,
                                           num_flatten_dims=len(x.shape) - 1,
                                           param_attr=fluid.ParamAttr(
                                               name="output_w",
                                               initializer=uniform_initializer(
                                                   self.init_scale)),
                                           bias_attr=False)

        if mode == 'train':
            dec_output, dec_final_state = rnn(cell=dec_cell,
                                              inputs=self.tar_emb,
                                              initial_states=enc_final_state)

            dec_output = output_layer(dec_output)

            return dec_output
        elif mode == 'beam_search':
            beam_search_decoder = BeamSearchDecoder(
                dec_cell,
                self.beam_start_token,
                self.beam_end_token,
                beam_size,
                embedding_fn=self.tar_embeder,
                output_fn=output_layer)

            outputs, _ = dynamic_decode(beam_search_decoder,
                                        inits=enc_final_state,
                                        max_step_num=self.beam_max_step_num)
            return outputs
Code example #5
    def _build_decoder(self, enc_final_state, mode='train', beam_size=10):
        output_layer = lambda x: layers.fc(
            x,
            size=self.tar_vocab_size,
            num_flatten_dims=len(x.shape) - 1,
            param_attr=fluid.ParamAttr(
                name="output_w",
                initializer=fluid.initializer.UniformInitializer(
                    low=-self.init_scale, high=self.init_scale)),
            bias_attr=False)

        dec_cell = AttentionDecoderCell(self.num_layers, self.hidden_size,
                                        self.dropout, self.init_scale)
        dec_initial_states = [
            enc_final_state,
            dec_cell.get_initial_states(batch_ref=self.enc_output,
                                        shape=[self.hidden_size])
        ]
        max_src_seq_len = layers.shape(self.src)[1]
        src_mask = layers.sequence_mask(self.src_sequence_length,
                                        maxlen=max_src_seq_len,
                                        dtype='float32')
        enc_padding_mask = (src_mask - 1.0)
        if mode == 'train':
            dec_output, _ = rnn(cell=dec_cell,
                                inputs=self.tar_emb,
                                initial_states=dec_initial_states,
                                sequence_length=None,
                                enc_output=self.enc_output,
                                enc_padding_mask=enc_padding_mask)

            dec_output = output_layer(dec_output)

        elif mode == 'beam_search':
            output_layer = lambda x: layers.fc(
                x,
                size=self.tar_vocab_size,
                num_flatten_dims=len(x.shape) - 1,
                param_attr=fluid.ParamAttr(name="output_w"),
                bias_attr=False)
            beam_search_decoder = BeamSearchDecoder(
                dec_cell,
                self.beam_start_token,
                self.beam_end_token,
                beam_size,
                embedding_fn=self.tar_embeder,
                output_fn=output_layer)
            enc_output = beam_search_decoder.tile_beam_merge_with_batch(
                self.enc_output, beam_size)
            enc_padding_mask = beam_search_decoder.tile_beam_merge_with_batch(
                enc_padding_mask, beam_size)
            outputs, _ = dynamic_decode(beam_search_decoder,
                                        inits=dec_initial_states,
                                        max_step_num=self.beam_max_step_num,
                                        enc_output=enc_output,
                                        enc_padding_mask=enc_padding_mask)
            return outputs

        return dec_output
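As a quick illustration of the `(src_mask - 1.0)` convention above, a NumPy sketch with made-up lengths (not code from the project):

import numpy as np

# one sequence of valid length 3 inside a padded length of 5
src_mask = np.array([1.0, 1.0, 1.0, 0.0, 0.0])
enc_padding_mask = src_mask - 1.0
print(enc_padding_mask)  # [ 0.  0.  0. -1. -1.]
# the attention cell can scale this (e.g. by a large constant) and add it to
# the attention logits so padded positions receive negligible weight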
Code example #6
    def _build_encoder(self):
        enc_cell = EncoderCell(self.num_layers, self.hidden_size, self.dropout,
                               self.init_scale)
        self.enc_output, enc_final_state = rnn(
            cell=enc_cell,
            inputs=self.src_emb,
            sequence_length=self.src_sequence_length)
        return self.enc_output, enc_final_state
Code example #7
def decoder(encoder_output,
            encoder_output_proj,
            encoder_state,
            encoder_padding_mask,
            trg=None,
            is_train=True):
    """Decoder: GRU with Attention"""
    decoder_cell = DecoderCell(hidden_size=decoder_size)
    decoder_initial_states = layers.fc(encoder_state,
                                       size=decoder_size,
                                       act="tanh")
    trg_embeder = lambda x: fluid.embedding(
        input=x,
        size=[target_dict_size, hidden_dim],
        dtype="float32",
        param_attr=fluid.ParamAttr(name="trg_emb_table"))
    output_layer = lambda x: layers.fc(x,
                                       size=target_dict_size,
                                       num_flatten_dims=len(x.shape) - 1,
                                       param_attr=fluid.ParamAttr(name=
                                                                  "output_w"))
    if is_train:
        decoder_output, _ = layers.rnn(
            cell=decoder_cell,
            inputs=trg_embeder(trg),
            initial_states=decoder_initial_states,
            time_major=False,
            encoder_output=encoder_output,
            encoder_output_proj=encoder_output_proj,
            encoder_padding_mask=encoder_padding_mask)
        decoder_output = output_layer(decoder_output)
    else:
        encoder_output = layers.BeamSearchDecoder.tile_beam_merge_with_batch(
            encoder_output, beam_size)
        encoder_output_proj = layers.BeamSearchDecoder.tile_beam_merge_with_batch(
            encoder_output_proj, beam_size)
        encoder_padding_mask = layers.BeamSearchDecoder.tile_beam_merge_with_batch(
            encoder_padding_mask, beam_size)
        beam_search_decoder = layers.BeamSearchDecoder(
            cell=decoder_cell,
            start_token=bos_id,
            end_token=eos_id,
            beam_size=beam_size,
            embedding_fn=trg_embeder,
            output_fn=output_layer)
        decoder_output, _ = layers.dynamic_decode(
            decoder=beam_search_decoder,
            inits=decoder_initial_states,
            max_step_num=max_length,
            output_time_major=False,
            encoder_output=encoder_output,
            encoder_output_proj=encoder_output_proj,
            encoder_padding_mask=encoder_padding_mask)

    return decoder_output
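A hedged sketch of an inference-time call to this decoder, assuming the input tensors come from the matching encoder; the shape comment is an assumption about Paddle's beam-search output layout:

predictions = decoder(encoder_output,
                      encoder_output_proj,
                      encoder_state,
                      encoder_padding_mask,
                      is_train=False)
# with output_time_major=False, predicted ids come back roughly as
# [batch_size, decoded_steps, beam_size]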
Code example #8
def encoder(src_embedding, src_sequence_length):
    """Encoder: Bidirectional GRU"""
    encoder_fwd_cell = layers.GRUCell(hidden_size=hidden_dim)
    encoder_fwd_output, fwd_state = layers.rnn(
        cell=encoder_fwd_cell,
        inputs=src_embedding,
        sequence_length=src_sequence_length,
        time_major=False,
        is_reverse=False)
    encoder_bwd_cell = layers.GRUCell(hidden_size=hidden_dim)
    encoder_bwd_output, bwd_state = layers.rnn(
        cell=encoder_bwd_cell,
        inputs=src_embedding,
        sequence_length=src_sequence_length,
        time_major=False,
        is_reverse=True)
    encoder_output = layers.concat(
        input=[encoder_fwd_output, encoder_bwd_output], axis=2)
    encoder_state = layers.concat(input=[fwd_state, bwd_state], axis=1)
    return encoder_output, encoder_state
Code example #9
File: lm_model.py  Project: guoshengCS/rnn-benchmark
    def seq2seq_api_rnn(input_embedding,
                        len=3,
                        init_hiddens=None,
                        init_cells=None):
        class EncoderCell(layers.RNNCell):
            def __init__(self,
                         num_layers,
                         hidden_size,
                         dropout_prob=0.,
                         forget_bias=0.):
                self.num_layers = num_layers
                self.hidden_size = hidden_size
                self.dropout_prob = dropout_prob
                self.lstm_cells = []
                for i in range(num_layers):
                    self.lstm_cells.append(
                        layers.LSTMCell(
                            hidden_size,
                            forget_bias=forget_bias,
                            param_attr=fluid.ParamAttr(
                                initializer=fluid.initializer.
                                UniformInitializer(low=-init_scale,
                                                   high=init_scale))))

            def call(self, step_input, states):
                new_states = []
                for i in range(self.num_layers):
                    out, new_state = self.lstm_cells[i](step_input, states[i])
                    step_input = layers.dropout(
                        out,
                        self.dropout_prob,
                        dropout_implementation='upscale_in_train'
                    ) if self.dropout_prob > 0 else out
                    new_states.append(new_state)
                return step_input, new_states

        cell = EncoderCell(num_layers, hidden_size, dropout)
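        # Shape note (annotation, sizes are whatever the caller passes):
        # init_hiddens / init_cells arrive as [num_layers, batch, hidden_size];
        # splitting on dim 0 and reshaping yields per-layer [batch, hidden_size]
        # tensors, one (hidden, cell) pair for each LSTMCell.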
        output, new_states = layers.rnn(
            cell,
            inputs=input_embedding,
            initial_states=[[hidden, cell] for hidden, cell in zip([
                layers.reshape(init_hidden, shape=[-1, hidden_size])
                for init_hidden in layers.split(
                    init_hiddens, num_or_sections=num_layers, dim=0)
            ], [
                layers.reshape(init_cell, shape=[-1, hidden_size])
                for init_cell in layers.split(
                    init_cells, num_or_sections=num_layers, dim=0)
            ])],
            time_major=False)
        last_hidden = layers.stack([hidden for hidden, _ in new_states], 0)
        last_cell = layers.stack([cell for _, cell in new_states], 0)
        return output, last_hidden, last_cell
Code example #10
File: model.py  Project: wbj0110/models
    def _build_encoder(self):
        self.enc_input = layers.dropout(
            self.src_emb,
            dropout_prob=self.enc_dropout_in,
            dropout_implementation="upscale_in_train")
        enc_cell = EncoderCell(self.num_layers, self.hidden_size,
                               self.param_attr_initializer,
                               self.param_attr_scale, self.enc_dropout_out)
        enc_output, enc_final_state = rnn(
            cell=enc_cell,
            inputs=self.enc_input,
            sequence_length=self.src_sequence_length)
        return enc_output, enc_final_state
Code example #11
File: generation_task.py  Project: wuhuaha/beike
    def _build_net(self):
        self.seq_len = fluid.layers.data(name="seq_len",
                                         shape=[1],
                                         dtype='int64',
                                         lod_level=0)
        self.seq_len_used = fluid.layers.squeeze(self.seq_len, axes=[1])
        src_mask = fluid.layers.sequence_mask(self.seq_len_used,
                                              maxlen=self.max_seq_len,
                                              dtype='float32')
        enc_padding_mask = (src_mask - 1.0)

        # Define decoder and initialize it.
        dec_cell = AttentionDecoderCell(self.num_layers, self.hidden_size,
                                        self.dropout)
        dec_init_hidden = fluid.layers.fc(
            input=self.feature,
            size=self.hidden_size,
            num_flatten_dims=1,
            param_attr=fluid.ParamAttr(
                name="dec_init_hidden_w",
                initializer=fluid.initializer.TruncatedNormal(scale=0.02)),
            bias_attr=fluid.ParamAttr(
                name="dec_init_hidden_b",
                initializer=fluid.initializer.Constant(0.)))
        dec_initial_states = [[[
            dec_init_hidden,
            dec_cell.get_initial_states(batch_ref=self.feature,
                                        shape=[self.hidden_size])
        ]] * self.num_layers,
                              dec_cell.get_initial_states(
                                  batch_ref=self.feature,
                                  shape=[self.hidden_size])]
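        # (annotation) The first element supplies one [h, c]-style pair per
        # decoder layer, all initialized from dec_init_hidden; the second is an
        # extra per-step state (e.g. attention/input-feed) that
        # AttentionDecoderCell presumably tracks.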
        tar_vocab_size = len(self._label_list)
        tar_embeder = lambda x: fluid.embedding(
            input=x,
            size=[tar_vocab_size, self.hidden_size],
            dtype='float32',
            is_sparse=False,
            param_attr=fluid.ParamAttr(name='target_embedding',
                                       initializer=fluid.initializer.
                                       UniformInitializer(low=-0.1, high=0.1)))
        start_token_id = self._label_list.index(self.start_token)
        end_token_id = self._label_list.index(self.end_token)
        if not self.is_predict_phase:
            self.dec_input = fluid.layers.data(name="dec_input",
                                               shape=[self.max_seq_len],
                                               dtype='int64')
            tar_emb = tar_embeder(self.dec_input)
            dec_output, _ = rnn(cell=dec_cell,
                                inputs=tar_emb,
                                initial_states=dec_initial_states,
                                sequence_length=None,
                                enc_output=self.token_feature,
                                enc_padding_mask=enc_padding_mask)
            self.logits = fluid.layers.fc(
                dec_output,
                size=tar_vocab_size,
                num_flatten_dims=len(dec_output.shape) - 1,
                param_attr=fluid.ParamAttr(
                    name="output_w",
                    initializer=fluid.initializer.UniformInitializer(
                        low=-0.1, high=0.1)))
            self.ret_infers = fluid.layers.reshape(
                x=fluid.layers.argmax(self.logits, axis=2), shape=[-1, 1])
            logits = fluid.layers.softmax(self.logits)
            return [logits]
        else:
            output_layer = lambda x: fluid.layers.fc(
                x,
                size=tar_vocab_size,
                num_flatten_dims=len(x.shape) - 1,
                param_attr=fluid.ParamAttr(name="output_w"))
            beam_search_decoder = BeamSearchDecoder(dec_cell,
                                                    start_token_id,
                                                    end_token_id,
                                                    self.beam_size,
                                                    embedding_fn=tar_embeder,
                                                    output_fn=output_layer)
            enc_output = beam_search_decoder.tile_beam_merge_with_batch(
                self.token_feature, self.beam_size)
            enc_padding_mask = beam_search_decoder.tile_beam_merge_with_batch(
                enc_padding_mask, self.beam_size)
            self.ret_infers, _ = dynamic_decode(
                beam_search_decoder,
                inits=dec_initial_states,
                max_step_num=self.beam_max_step_num,
                enc_output=enc_output,
                enc_padding_mask=enc_padding_mask)
            return self.ret_infers
Code example #12
File: model.py  Project: wbj0110/models
    def _build_decoder(self,
                       z_mean=None,
                       z_log_var=None,
                       enc_output=None,
                       mode='train',
                       beam_size=10):
        dec_input = layers.dropout(self.tar_emb,
                                   dropout_prob=self.dec_dropout_in,
                                   dropout_implementation="upscale_in_train")

        # `output_layer` will be used within BeamSearchDecoder
        output_layer = lambda x: layers.fc(x,
                                           size=self.tar_vocab_size,
                                           num_flatten_dims=len(x.shape) - 1,
                                           name="output_w")

        # `sample_output_layer` samples an id from the logits distribution instead of argmax(logits)
        # it will be used within BeamSearchDecoder
        sample_output_layer = lambda x: layers.unsqueeze(
            fluid.one_hot(layers.unsqueeze(
                layers.sampling_id(layers.softmax(
                    layers.squeeze(output_layer(x), [1])),
                                   dtype='int'), [1]),
                          depth=self.tar_vocab_size), [1])
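        # (annotation) Reading the lambda inside-out: squeeze the beam dim from
        # the logits, softmax, draw one id per row with sampling_id, then turn
        # the sampled id back into a one-hot "logits" tensor, so the
        # beam_size=1 search below effectively performs sampling.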

        if mode == 'train':
            latent_z = self._sampling(z_mean, z_log_var)
        else:
            latent_z = layers.gaussian_random_batch_size_like(
                self.tar, shape=[-1, self.latent_size])
        dec_first_hidden_cell = layers.fc(latent_z,
                                          2 * self.hidden_size *
                                          self.num_layers,
                                          name='fc_hc')
        dec_first_hidden, dec_first_cell = layers.split(
            dec_first_hidden_cell, 2)
        if self.num_layers > 1:
            dec_first_hidden = layers.split(dec_first_hidden, self.num_layers)
            dec_first_cell = layers.split(dec_first_cell, self.num_layers)
        else:
            dec_first_hidden = [dec_first_hidden]
            dec_first_cell = [dec_first_cell]
        dec_initial_states = [[h, c]
                              for h, c in zip(dec_first_hidden, dec_first_cell)
                              ]
        dec_cell = DecoderCell(self.num_layers, self.hidden_size, latent_z,
                               self.param_attr_initializer,
                               self.param_attr_scale, self.dec_dropout_out)

        if mode == 'train':
            dec_output, _ = rnn(cell=dec_cell,
                                inputs=dec_input,
                                initial_states=dec_initial_states,
                                sequence_length=self.tar_sequence_length)
            dec_output = output_layer(dec_output)

            return dec_output
        elif mode == 'greedy':
            start_token = 1
            end_token = 2
            max_length = 100
            beam_search_decoder = BeamSearchDecoder(
                dec_cell,
                start_token,
                end_token,
                beam_size=1,
                embedding_fn=self.tar_embeder,
                output_fn=output_layer)
            outputs, _ = dynamic_decode(beam_search_decoder,
                                        inits=dec_initial_states,
                                        max_step_num=max_length)
            return outputs

        elif mode == 'sampling':
            start_token = 1
            end_token = 2
            max_length = 100
            beam_search_decoder = BeamSearchDecoder(
                dec_cell,
                start_token,
                end_token,
                beam_size=1,
                embedding_fn=self.tar_embeder,
                output_fn=sample_output_layer)

            outputs, _ = dynamic_decode(beam_search_decoder,
                                        inits=dec_initial_states,
                                        max_step_num=max_length)
            return outputs
        else:
            print("mode not supprt", mode)
Code example #13
def decoder(encoder_output,
            encoder_output_proj,
            encoder_state,
            encoder_padding_mask,
            trg=None,
            is_train=True):
    # define the components the RNN needs
    decoder_cell = DecoderCell(hidden_size=decoder_size)
    decoder_initial_states = layers.fc(encoder_state,
                                       size=decoder_size,
                                       act="tanh")
    trg_embeder = lambda x: fluid.embedding(
        input=x,
        size=[target_dict_size, hidden_dim],
        dtype="float32",
        param_attr=fluid.ParamAttr(name="trg_emb_table"))

    output_layer = lambda x: layers.fc(input=x,
                                       size=target_dict_size,
                                       num_flatten_dims=len(x.shape) - 1,
                                       param_attr=fluid.ParamAttr(name=
                                                                  "output_w"))
    if is_train:  # training
        # At training time, `layers.rnn` builds the recurrent network defined
        # by `cell`; each step slices its input from `inputs` and runs
        # `cell.call`. Shapes: output [-1, -1, 512], state [-1, 512].
        decoder_output, _ = layers.rnn(
            cell=decoder_cell,
            inputs=trg_embeder(trg),
            initial_states=decoder_initial_states,
            time_major=False,
            encoder_output=encoder_output,
            encoder_output_proj=encoder_output_proj,
            encoder_padding_mask=encoder_padding_mask)

        decoder_output = layers.fc(input=decoder_output,
                                   size=target_dict_size,
                                   num_flatten_dims=2,
                                   param_attr=fluid.ParamAttr(name="output_w"))

    else:  # generation with beam search
        # for beam search, tensors of shape `[batch_size, ...]` must be tiled
        # to `[batch_size * beam_size, ...]`
        encoder_output = layers.BeamSearchDecoder.tile_beam_merge_with_batch(
            encoder_output, beam_size)
        encoder_output_proj = layers.BeamSearchDecoder.tile_beam_merge_with_batch(
            encoder_output_proj, beam_size)
        encoder_padding_mask = layers.BeamSearchDecoder.tile_beam_merge_with_batch(
            encoder_padding_mask, beam_size)
        # BeamSearchDecoder defines the single-step decoding op: `cell.call` + `beam_search_step`
        beam_search_decoder = layers.BeamSearchDecoder(
            cell=decoder_cell,
            start_token=bos_id,
            end_token=eos_id,
            beam_size=beam_size,
            embedding_fn=trg_embeder,
            output_fn=output_layer)
        # decode dynamically with layers.dynamic_decode: `decoder.step()` runs
        # repeatedly until the returned finished tensor is all True or the
        # step count reaches `max_step_num`
        decoder_output, _ = layers.dynamic_decode(
            decoder=beam_search_decoder,
            inits=decoder_initial_states,
            max_step_num=max_length,
            output_time_major=False,
            encoder_output=encoder_output,
            encoder_output_proj=encoder_output_proj,
            encoder_padding_mask=encoder_padding_mask)

    return decoder_output
Code example #14
    def forward(self,
                inputs,
                initial_states=None,
                sequence_length=None,
                **kwargs):
        if F.in_dygraph_mode():

            class OutputArray(object):
                def __init__(self, x):
                    self.array = [x]

                def append(self, x):
                    self.array.append(x)

            def _maybe_copy(state, new_state, step_mask):
                # TODO: use where_op
                new_state = L.elementwise_mul(new_state, step_mask, axis=0) - \
                        L.elementwise_mul(state, (step_mask - 1), axis=0)
                return new_state

            #logging.info("inputs shape: {}".format(inputs.shape))
            flat_inputs = U.flatten(inputs)
            #logging.info("flat inputs len: {}".format(len(flat_inputs)))
            #logging.info("flat inputs[0] shape: {}".format(flat_inputs[0].shape))

            batch_size, time_steps = (
                flat_inputs[0].shape[self.batch_index],
                flat_inputs[0].shape[self.time_step_index])
            #logging.info("batch_size: {}".format(batch_size))
            #logging.info("time_steps: {}".format(time_steps))

            if initial_states is None:
                initial_states = self.cell.get_initial_states(
                    batch_ref=inputs, batch_dim_idx=self.batch_index)

            if not self.time_major:
                # if dim 0 is not the time step, swap dims 0 and 1
                # so that dim 0 becomes the time step
                inputs = U.map_structure(
                    lambda x: L.transpose(x, [1, 0] + list(
                        range(2, len(x.shape)))), inputs)

            if sequence_length is not None:
                mask = L.sequence_mask(
                    sequence_length,
                    maxlen=time_steps,
                    dtype=U.flatten(initial_states)[0].dtype)
                # likewise, put the time step on dim 0
                mask = L.transpose(mask, [1, 0])

            if self.is_reverse:
                # for a reverse RNN,
                # flip the inputs along dim 0
                inputs = U.map_structure(lambda x: L.reverse(x, axis=[0]), inputs)
                mask = L.reverse(mask, axis=[0]) if sequence_length is not None else None

            states = initial_states
            outputs = []
            # iterate over the time steps
            for i in range(time_steps):
                # slice out this step's input
                step_inputs = U.map_structure(lambda x: x[i], inputs)
                # feed the current input and states,
                # getting the step output and new states
                step_outputs, new_states = self.cell(step_inputs, states, **kwargs)
                if sequence_length is not None:
                    # where positions are masked out, keep the old state values
                    # _maybe_copy: unmasked entries take new_states, masked ones keep states
                    new_states = U.map_structure(
                        partial(_maybe_copy, step_mask=mask[i]),
                        states,
                        new_states)
                states = new_states
                #logging.info("step_output shape: {}".format(step_outputs.shape))

                if i == 0:
                    # first step: wrap each output in an array
                    outputs = U.map_structure(lambda x: OutputArray(x), step_outputs)
                else:
                    # append each step output to its array
                    U.map_structure(lambda x, x_array: x_array.append(x), step_outputs, outputs)

            # finally, stack outputs along the time-step dimension
            final_outputs = U.map_structure(
                lambda x: L.stack(x.array, axis=self.time_step_index),
                outputs)
            #logging.info("final_outputs shape: {}".format(final_outputs.shape))

            if self.is_reverse:
                # for a reverse RNN, flip the final outputs back
                final_outputs = U.map_structure(
                    lambda x: L.reverse(x, axis=self.time_step_index),
                    final_outputs)

            final_states = new_states

        else:
            final_outputs, final_states = L.rnn(
                self.cell,
                inputs,
                initial_states=initial_states,
                sequence_length=sequence_length,
                time_major=self.time_major,
                is_reverse=self.is_reverse,
                **kwargs)

        return final_outputs, final_states
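To make the `_maybe_copy` arithmetic in the dygraph branch concrete, a small NumPy illustration with made-up values:

import numpy as np

state = np.array([[1.0, 1.0], [2.0, 2.0]])      # states before this step
new_state = np.array([[9.0, 9.0], [8.0, 8.0]])  # states after this step
step_mask = np.array([1.0, 0.0])                # sample 2 is already padded

# new_state * mask - state * (mask - 1): mask == 1 keeps the fresh state,
# mask == 0 keeps the previous one, so padding never disturbs final states
out = new_state * step_mask[:, None] - state * (step_mask[:, None] - 1.0)
print(out)  # [[9. 9.]
            #  [2. 2.]]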