Example #1
 def __init__(self, units):
     super(MyGRU, self).__init__()
     self.state0 = [tf.zeros(shape=(batch_size, units))]
     self.state1 = [tf.zeros(shape=(batch_size, units))]
     self.embedding = layers.Embedding(num_words, 100, input_length=seq_len)
     self.cell0 = layers.GRUCell(units, dropout=0.5)
     self.cell1 = layers.GRUCell(units, dropout=0.5)
     self.outlayer = layers.Dense(1)
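Examples #1-#3 define only the constructor; in these sentiment-classification models the companion call() typically unrolls the two GRU cells over the time axis. A minimal sketch under that assumption, reusing the attribute names from Example #1:

 def call(self, inputs, training=None):
     # inputs: [b, seq_len] token ids -> [b, seq_len, 100] embeddings
     x = self.embedding(inputs)
     state0, state1 = self.state0, self.state1
     for word in tf.unstack(x, axis=1):  # word: [b, 100]
         out0, state0 = self.cell0(word, state0, training)
         out1, state1 = self.cell1(out0, state1, training)
     # classify the last timestep's features: [b, units] -> [b, 1]
     return self.outlayer(out1)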
Example #2
 def __init__(self, units):
     super(MyRNN, self).__init__()
     # [b, 64]
     self.state0 = [tf.zeros([bs, units])]
     self.state1 = [tf.zeros([bs, units])]
     # [b, 80] -> [b, 80, 100]
     self.embedding = layers.Embedding(total_words,
                                       embedding_len,
                                       input_length=max_len)
     self.rnn_cell0 = layers.GRUCell(units, dropout=0.5)
     self.rnn_cell1 = layers.GRUCell(units, dropout=0.5)
     # [b, 80, 100] -> [b, 64] -> [b, 1]
     self.out_layer = layers.Dense(1)
Example #3
 def __init__(self, units):
     super(MyLSTM, self).__init__()
     # initialize [h0]
     self.state0 = [tf.zeros([batchSize, units])]
     self.state1 = [tf.zeros([batchSize, units])]
     # [b,200] => [b,200,100]
     self.embedding = layers.Embedding(input_dim=total_words,
                                       input_length=max_review_len,
                                       output_dim=embedding_len)
     self.gru_cell0 = layers.GRUCell(units=units, dropout=0.5)
     self.gru_cell1 = layers.GRUCell(units=units, dropout=0.5)
     self.outlayer1 = layers.Dense(32)
     self.outlayer2 = layers.Dense(1)
Example #4
 def __init__(
     self, stoch=30, deter=200, hidden=200, layers_input=1, layers_output=1,
     rec_depth=1, shared=False, discrete=False, act=tf.nn.elu,
     mean_act='none', std_act='softplus', temp_post=True, min_std=0.1,
     cell='keras'):
   super().__init__()
   self._stoch = stoch
   self._deter = deter
   self._hidden = hidden
   self._min_std = min_std
   self._layers_input = layers_input
   self._layers_output = layers_output
   self._rec_depth = rec_depth
   self._shared = shared
   self._discrete = discrete
   self._act = act
   self._mean_act = mean_act
   self._std_act = std_act
   self._temp_post = temp_post
   self._embed = None
   if cell == 'gru':
     self._cell = tfkl.GRUCell(self._deter)
   elif cell == 'gru_layer_norm':
     self._cell = GRUCell(self._deter, norm=True)
   else:
     raise NotImplementedError(cell)
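The cell configured here is applied one step at a time inside the model's recurrent update rather than being wrapped in an RNN layer. A minimal sketch of that stepping pattern, with hypothetical shapes:

import tensorflow as tf
from tensorflow.keras import layers as tfkl

cell = tfkl.GRUCell(200)
x = tf.zeros([16, 200])      # hypothetical input features
deter = tf.zeros([16, 200])  # previous deterministic state
x, deter = cell(x, [deter])  # one recurrent step
deter = deter[0]             # GRUCell returns its states as a list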
Example #5
    def __init__(self, action_size, img_dim=32,
                    policy_clip=0.2, value_clip=0.2,
                    entropy_beta=0.01, val_discount=1.0, **kwargs):

        super(ppo_rnn, self).__init__()

        ''' Hyperparameters '''
        self.action_size  = action_size
        self.value_size   = 1
        self.policy_clip  = policy_clip
        self.value_clip   = value_clip
        self.entropy_beta = entropy_beta
        self.val_discount = val_discount


        ''' Networks '''
        self.embed_fn = keras.Sequential()
        self.embed_fn.add(layers.Conv2D(32, 8, 4, padding='same', activation=Mish(), kernel_initializer='lecun_normal'))
        self.embed_fn.add(layers.Conv2D(64, 4, 2, padding='same', activation=Mish(), kernel_initializer='lecun_normal'))
        self.embed_fn.add(layers.Conv2D(64, 3, 1, padding='same', activation=Mish(), kernel_initializer='lecun_normal'))
        self.embed_fn.add(layers.Flatten())
        self.embed_fn.add(layers.Dense(512, activation=Mish(), kernel_initializer='lecun_normal'))
        # self.embed_fn.add(layers.Dense(128, activation=Mish(), kernel_initializer='lecun_normal'))

        self.rnn_cell = layers.GRUCell(512)

        self.policy_fn = keras.Sequential()
        # self.policy_fn.add(layers.Dense(128, activation=Mish(), kernel_initializer='lecun_normal'))
        self.policy_fn.add(layers.Dense(self.action_size, activation='linear', kernel_initializer='lecun_normal'))

        self.value_fn = keras.Sequential()
        # self.value_fn.add(layers.Dense(128, activation=Mish(), kernel_initializer='lecun_normal'))
        self.value_fn.add(layers.Dense(self.value_size, activation='linear', kernel_initializer='lecun_normal'))
Example #6
 def __init__(self, stoch=30, deter=200, hidden=200, act=tf.nn.elu):
     super().__init__()
     self._activation = act
     self._stoch_size = stoch
     self._deter_size = deter
     self._hidden_size = hidden
     self._cell = tfkl.GRUCell(self._deter_size)
Example #7
 def __init__(self, units):
     super(MyRNN, self).__init__()
     # [b, 64], build the cells' initial state vectors for reuse
     self.state0 = [tf.zeros([batchsz, units])]
     self.state1 = [tf.zeros([batchsz, units])]
     # word embedding: [b, 80] => [b, 80, 100]
     self.embedding = layers.Embedding(total_words, embedding_len, input_length=max_review_len)
     # build 2 cells
     self.rnn_cell0 = layers.GRUCell(units, dropout=0.5)
     self.rnn_cell1 = layers.GRUCell(units, dropout=0.5)
     # build the classification network that classifies the cells' output features (2 classes)
     # [b, 80, 100] => [b, 64] => [b, 1]
     self.outlayer = Sequential([
         layers.Dense(units),
         layers.Dropout(rate=0.5),
         layers.ReLU(),
         layers.Dense(1)])
Example #8
def get_cell(cell_type, units):
    cell_type = cell_type.lower()
    if cell_type == 'lstm':
        return layers.LSTMCell(units=units)
    if cell_type == 'gru':
        return layers.GRUCell(units=units)
    if cell_type == 'simple':
        return RNNCell(units=units)
    raise ValueError('Unknown RNN cell type')
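Cells returned by a factory like this carry no time-axis logic of their own; they are usually wrapped in a layers.RNN layer. A short usage sketch:

import tensorflow as tf
from tensorflow.keras import layers

cell = layers.GRUCell(units=64)
rnn = layers.RNN(cell, return_sequences=True)  # unrolls the cell over time
out = rnn(tf.random.normal([2, 10, 8]))        # [batch, time, feat] -> [2, 10, 64]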
Example #9
    def __init__(self, config):
        super(DecoderAtt, self).__init__(name="trajectory_decoder")
        self.add_social     = config.add_social
        self.stack_rnn_size = config.stack_rnn_size
        self.rnn_type       = config.rnn_type
        # Linear embedding of the encoding resulting observed trajectories
        self.traj_xy_emb_dec = layers.Dense(config.emb_size,
            activation=config.activation_func,
            name='trajectory_position_embedding')
        # RNN cell
        # Condition for cell type
        if self.rnn_type == 'gru':
            # GRU cell
            self.dec_cell_traj = layers.GRUCell(config.dec_hidden_size,
                                                recurrent_initializer='glorot_uniform',
                                                dropout=config.dropout_rate,
                                                recurrent_dropout=config.dropout_rate,
                                                name='trajectory_decoder_GRU_cell')
        else:
            # LSTM cell
            self.dec_cell_traj = layers.LSTMCell(config.dec_hidden_size,
                                                recurrent_initializer='glorot_uniform',
                                                name='trajectory_decoder_LSTM_cell',
                                                dropout=config.dropout_rate,
                                                recurrent_dropout=config.dropout_rate)
        # RNN layer
        self.recurrentLayer = layers.RNN(self.dec_cell_traj,return_sequences=True,return_state=True)
        self.M = 1
        if (self.add_social):
            self.M=self.M+1

        # Attention layer
        self.focal_attention = FocalAttention(config,self.M)
        # Dropout layer
        self.dropout = layers.Dropout(config.dropout_rate,name="dropout_decoder_h")
        # Mapping from h to positions
        self.h_to_xy = layers.Dense(config.P,
            activation=tf.identity,
            name='h_to_xy')

        # Input layers
        # Position input
        dec_input_shape      = (1,config.P)
        self.input_layer_pos = layers.Input(dec_input_shape,name="position")
        enc_last_state_shape = (config.dec_hidden_size,)
        # Proposals for initial states
        self.input_layer_hid1= layers.Input(enc_last_state_shape,name="initial_state_h")
        self.input_layer_hid2= layers.Input(enc_last_state_shape,name="initial_state_c")
        # Context shape: [N,M,T1,h_dim]
        ctxt_shape = (self.M,config.obs_len,config.enc_hidden_size)
        # Context input
        self.input_layer_ctxt = layers.Input(ctxt_shape,name="context")
        self.out = self.call((self.input_layer_pos,(self.input_layer_hid1,self.input_layer_hid2),self.input_layer_ctxt))
        # Call __init__ again; this is a workaround so that summary() can be used
        super(DecoderAtt, self).__init__(
                    inputs= [self.input_layer_pos,self.input_layer_hid1,self.input_layer_hid2,self.input_layer_ctxt],
                    outputs=self.out)
Example #10
 def __init__(self, vocab_size, embed_size, num_hiddens, num_layers, dropout=0):
     super(Seq2SeqAttentionDecoder, self).__init__()
     self.attention = AdditiveAttention(num_hiddens=8, dropout=0.1)
     self.embed = layers.Embedding(input_dim=vocab_size, output_dim=embed_size)
     self.rnn = layers.RNN(
         layers.StackedRNNCells([layers.GRUCell(units=num_hiddens, dropout=dropout) for _ in range(num_layers)])
         , return_state=True
         , return_sequences=True
     )
     self.dense = layers.Dense(units=vocab_size)
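With StackedRNNCells and return_state=True, the RNN call returns the output sequence followed by one final state per stacked cell. A minimal standalone sketch:

import tensorflow as tf
from tensorflow.keras import layers

rnn = layers.RNN(
    layers.StackedRNNCells([layers.GRUCell(16) for _ in range(2)]),
    return_sequences=True, return_state=True)
result = rnn(tf.random.normal([4, 7, 8]))
outputs, states = result[0], result[1:]  # outputs: [4, 7, 16]; one state per cell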
Example #11
    def __init__(self, config):
        super(DecoderOf, self).__init__(name="trajectory_decoder")
        self.rnn_type = config.rnn_type
        # Linear embedding of the encoding resulting observed trajectories
        self.traj_xy_emb_dec = layers.Dense(
            config.emb_size,
            activation=config.activation_func,
            name='trajectory_position_embedding')
        # RNN cell
        # Condition for cell type
        if self.rnn_type == 'gru':
            # GRU cell
            self.dec_cell_traj = layers.GRUCell(
                config.dec_hidden_size,
                recurrent_initializer='glorot_uniform',
                dropout=config.dropout_rate,
                recurrent_dropout=config.dropout_rate,
                name='trajectory_decoder_cell_with_GRU')
        else:
            # LSTM cell
            self.dec_cell_traj = layers.LSTMCell(
                config.dec_hidden_size,
                recurrent_initializer='glorot_uniform',
                name='trajectory_decoder_cell_with_LSTM',
                dropout=config.dropout_rate,
                recurrent_dropout=config.dropout_rate)
        # RNN layer
        self.recurrentLayer = layers.RNN(self.dec_cell_traj,
                                         return_sequences=True,
                                         return_state=True)
        # Dropout layer
        self.dropout = layers.Dropout(config.dropout_rate,
                                      name="dropout_decoder_h")
        # Mapping from h to positions
        self.h_to_xy = layers.Dense(config.P,
                                    activation=tf.identity,
                                    name='h_to_xy')

        # Input layers
        # Position input
        dec_input_shape = (1, config.P)
        self.input_layer_pos = layers.Input(dec_input_shape, name="position")
        enc_last_state_shape = (config.dec_hidden_size,)
        # Proposals for initial states
        self.input_layer_hid1 = layers.Input(enc_last_state_shape,
                                             name="initial_state_h")
        self.input_layer_hid2 = layers.Input(enc_last_state_shape,
                                             name="initial_state_c")
        self.out = self.call((self.input_layer_pos, (self.input_layer_hid1,
                                                     self.input_layer_hid2)))
        # Call __init__ again; this is a workaround so that summary() can be used
        super(DecoderOf, self).__init__(inputs=[
            self.input_layer_pos, self.input_layer_hid1, self.input_layer_hid2
        ],
                                        outputs=self.out)
Example #12
    def __init__(self, vocab_size, embedding_size, gru_layers, units, name='GRU_Generator'):
        super(SequenceGenerator, self).__init__(name=name)
        self.vocab_size = vocab_size
        self.embedding_size = embedding_size
        self.gru_layers = gru_layers
        self.units = units

        self.embed = layers.Dense(embedding_size, use_bias=False)
        self.cells = []
        for layer in range(gru_layers):
            self.cells.append(layers.GRUCell(units[layer], recurrent_dropout=0.2))
        self.dense = layers.Dense(vocab_size)
Example #13
    def __init__(self, units):
        super(GRUUsingGRUCell, self).__init__()

        # [b, 64]
        self.state0 = [tf.zeros([batch_size, units])]
        self.state1 = [tf.zeros([batch_size, units])]

        # transform text to embedding representation
        # [b, 80] => [b, 80, 100]
        self.embedding = layers.Embedding(total_words,
                                          embedding_len,
                                          input_length=max_review_len)

        # [b, 80, 100] , h_dim: 64
        # RNN: cell1 ,cell2, cell3
        # SimpleRNN
        self.rnn_cell0 = layers.GRUCell(units, dropout=0.5)
        self.rnn_cell1 = layers.GRUCell(units, dropout=0.5)

        # fc, [b, 80, 100] => [b, 64] => [b, 1]
        self.output_layer = layers.Dense(1)
Example #14
def get_encoder(hidden_size,
                vocab_size,
                dropout=0.2,
                num_layers=2,
                bsize=58,
                msize=553,
                ssize=3191,
                dsize=405):
    pieces = Input(shape=(20, ), name='pieces')

    embedding = layers.Embedding(vocab_size, 200)
    pieces_emb = embedding(pieces)
    cells = [
        layers.GRUCell(hidden_size, dropout=dropout)
        for _ in range(num_layers - 1)
    ]
    cells.append(layers.GRUCell(hidden_size))
    gru = layers.RNN(cells,
                     return_sequences=False,
                     return_state=True,
                     name='multi-gru')

    output = gru(pieces_emb)[-1]

    feature = Sequential([
        layers.Dense(hidden_size, activations.relu),
        layers.Dropout(dropout),
        layers.Dense(hidden_size, activations.relu, name='final_feature')
    ],
                         name='feature_seq')(output)

    bcate = layers.Dense(bsize, name='bcateid')(feature)
    mcate = layers.Dense(msize, name='mcateid')(feature)
    scate = layers.Dense(ssize, name='scateid')(feature)
    dcate = layers.Dense(dsize, name='dcateid')(feature)

    model = Model(inputs=pieces, outputs=[bcate, mcate, scate, dcate])
    return model
Example #15
def get_encoder(hidden_size,vocab_size,
                 num_tokens=7,nlayers=2,dropout=0.2,
                 bsize=57, msize=552, ssize=3190, dsize=404):
    len_input=keras.Input(shape=(),name='len',dtype=tf.int64)
    pieces_input=[keras.Input(shape=(num_pieces,),name='piece{}'.format(i+1)) for i in range(num_tokens)]
    img_input=keras.Input(shape=(2048,),name='img')

    embedding=layers.Embedding(vocab_size,hidden_size,mask_zero=True)
    pieces=[embedding(piece) for piece in pieces_input]
    cells=[layers.GRUCell(hidden_size,dropout=dropout) for _ in range(nlayers-1)]
    cells.append(layers.LSTMCell(hidden_size))
    lstm=layers.RNN(cells,return_sequences=False,return_state=True,name='multi-gru')



    state=lstm(pieces[0])
    states=[state[-1][-1]]

    pieces.remove(pieces[0])

    for piece in pieces:
        state=lstm(piece)
        states.append(state[-1][-1])

    result=tf.math.add_n(states)
    sent_len=tf.reshape(len_input,(-1,1))
    #sent_len=tf.tile(len_input,[1,hidden_size])
    sent_len=tf.cast(sent_len,tf.float32)
    text_feat=tf.divide(result,sent_len,name='text_feature')

    img_feat=layers.Dense(hidden_size,activations.relu,name='img_feature')(img_input)

    text_plus_img=layers.concatenate([text_feat,img_feat],axis=1)

    feature=Sequential([
        layers.Dense(hidden_size,activations.relu),
        layers.Dropout(dropout),
        layers.Dense(hidden_size,activations.relu,name='final_feature')
    ],name='feature_seq')(text_plus_img)

    bcate = layers.Dense(bsize,name='bcateid')(feature)
    mcate = layers.Dense(msize,name='mcateid')(feature)
    scate = layers.Dense(ssize,name='scateid')(feature)
    dcate = layers.Dense(dsize,name='dcateid')(feature)

    inputs=[len_input,img_input]+pieces_input

    model=Model(inputs=inputs,outputs=[bcate,mcate,scate,dcate])
    return model
Example #16
class MyModel(Model):

    # layers.GRUCell()  # stray class-level statement: GRUCell requires a units argument and is unused in this model

    def __init__(self):
        super(MyModel, self).__init__()
        self.dense_1 = layers.Dense(10, name="predictions")
        self.dense_2 = layers.Dense(3, activation="softmax", name="class1")
        self.dense_3 = layers.Dense(3, activation="softmax", name="class2")

    def call(self, inputs, training=None):
        x = self.dense_1(inputs)
        y1 = self.dense_2(x)
        y2 = self.dense_3(x)
        return {"y1": y1, "y2": y2}
Example #17
File: nn.py Project: xlnwel/d2rl
    def __init__(self, name='rssm'):
        super().__init__(name)

        self._embed_layer = layers.Dense(self._hidden_size,
                                         activation=self._activation,
                                         name='embed')
        self._cell = layers.GRUCell(self._deter_size)
        self._img_layers = mlp([self._hidden_size],
                               out_size=2 * self._stoch_size,
                               activation=self._activation,
                               name='img')
        self._obs_layers = mlp([self._hidden_size],
                               out_size=2 * self._stoch_size,
                               activation=self._activation,
                               name='obs')
Example #18
    def __init__(self,
                 num_iterations,
                 num_slots,
                 slot_size,
                 mlp_hidden_size,
                 epsilon=1e-8):
        """Builds the Slot Attention module.

    Args:
      num_iterations: Number of iterations.
      num_slots: Number of slots.
      slot_size: Dimensionality of slot feature vectors.
      mlp_hidden_size: Hidden layer size of MLP.
      epsilon: Offset for attention coefficients before normalization.
    """
        super().__init__()
        self.num_iterations = num_iterations
        self.num_slots = num_slots
        self.slot_size = slot_size
        self.mlp_hidden_size = mlp_hidden_size
        self.epsilon = epsilon

        self.norm_inputs = layers.LayerNormalization()
        self.norm_slots = layers.LayerNormalization()
        self.norm_mlp = layers.LayerNormalization()

        # Parameters for Gaussian init (shared by all slots).
        self.slots_mu = self.add_weight(initializer="glorot_uniform",
                                        shape=[1, 1, self.slot_size],
                                        dtype=tf.float32,
                                        name="slots_mu")
        self.slots_log_sigma = self.add_weight(initializer="glorot_uniform",
                                               shape=[1, 1, self.slot_size],
                                               dtype=tf.float32,
                                               name="slots_log_sigma")

        # Linear maps for the attention module.
        self.project_q = layers.Dense(self.slot_size, use_bias=False, name="q")
        self.project_k = layers.Dense(self.slot_size, use_bias=False, name="k")
        self.project_v = layers.Dense(self.slot_size, use_bias=False, name="v")

        # Slot update functions.
        self.gru = layers.GRUCell(self.slot_size)
        self.mlp = tf.keras.Sequential([
            layers.Dense(self.mlp_hidden_size, activation="relu"),
            layers.Dense(self.slot_size)
        ],
                                       name="mlp")
Example #19
def rnn_cell(module_name):
    '''
    Constructs a basic RNN cell.

    :param module_name: one of 'gru' or 'lstm' ('lstmLN' is disabled below)
    :return: a Keras RNN cell with hidden_dim units
    '''

    # GRU. The hidden dimension is the number of hidden units in each RNN cell;
    # a cell's state is an n-dimensional vector, where n is the length of the MTS.
    if (module_name == 'gru'):
        rnn_cell = ll.GRUCell(units=hidden_dim, activation="tanh")
    # LSTM
    elif (module_name == 'lstm'):
        rnn_cell = ll.LSTMCell(units=hidden_dim, activation="tanh")
    # LSTM Layer Normalization
    '''elif (module_name == 'lstmLN'):
        rnn_cell = tf.contrib.rnn.LayerNormBasicLSTMCell(num_units=hidden_dim, activation="tanh")'''
    return rnn_cell
Example #20
    def __init__(self, latent_dim, n_atoms):

        super(RecurrentStateSpaceModel, self).__init__()

        self.latent_dim, self.n_atoms = latent_dim, n_atoms

        self.units = 600

        self.dense_z_prior1 = kl.Dense(self.units, activation="elu")

        self.dense_z_prior2 = kl.Dense(self.latent_dim * self.n_atoms)

        self.dense_z_post1 = kl.Dense(self.units, activation="elu")

        self.dense_z_post2 = kl.Dense(self.latent_dim * self.n_atoms)

        self.dense_h1 = kl.Dense(self.units, activation="elu")

        self.gru_cell = kl.GRUCell(self.units, activation="tanh")
Example #21
    def __init__(self,
                 units: int,
                 out_dim: int,
                 shift_std: float = 0.1,
                 cell_type: str = 'lstm',
                 offdiag: bool = False):
        """Constructs a learnable multivariate normal cell.

        Args:
          units: Dimensionality of the RNN function parameters.
          out_dim: The dimensionality of the distribution.
          shift_std: Shift applied to MVN std before building the dist. Providing a shift
            toward the expected std allows the input values to be closer to 0.
          cell_type: an RNN cell type among 'lstm', 'gru', 'rnn', 'gruclip'. case-insensitive.
          offdiag: set True to allow non-zero covariance (within-timestep) in the returned distribution.
        """
        super(LearnableMultivariateNormalCell, self).__init__()
        self.offdiag = offdiag
        self.output_dimensions = out_dim
        self.units = units
        if cell_type.upper().endswith('LSTM'):
            self.rnn_cell = tfkl.LSTMCell(self.units,
                                          implementation=1,
                                          name="mvncell")
            # why does the jupyter notebook version require implementation=1 but not in pycharm?
        elif cell_type.upper().endswith('GRU'):
            self.rnn_cell = tfkl.GRUCell(self.units, name="mvncell")
        elif cell_type.upper().endswith('RNN'):
            self.rnn_cell = tfkl.SimpleRNNCell(self.units, name="mvncell")
        elif cell_type.upper().endswith('GRUCLIP'):
            from indl.rnn.gru_clip import GRUClipCell
            self.rnn_cell = GRUClipCell(self.units, name="mvncell")
        else:
            raise ValueError("cell_type %s not recognized" % cell_type)

        self.loc_layer = tfkl.Dense(self.output_dimensions, name="mvncell_loc")
        n_scale_dim = (tfpl.MultivariateNormalTriL.params_size(out_dim) - out_dim) if offdiag\
            else (tfpl.IndependentNormal.params_size(out_dim) - out_dim)
        self.scale_untransformed_layer = tfkl.Dense(n_scale_dim,
                                                    name="mvndiagcell_scale")
        self._scale_shift = np.log(np.exp(shift_std) - 1).astype(np.float32)
Example #22
 def __init__(self,
              vocab_size,
              max_message_length,
              embed_dim,
              hidden_dim,
              vision_module,
              flexible_message_length=False,
              activation='linear'):
     super(Sender, self).__init__(vocab_size,
                                  max_message_length,
                                  embed_dim,
                                  hidden_dim,
                                  vision_module,
                                  flexible_message_length,
                                  VtoH_activation=activation)
     self.language_module = layers.GRUCell(hidden_dim, name='GRU_layer')
     self.hidden_to_output = layers.Dense(
         vocab_size, activation='linear',
         name='hidden_to_output')  # must be linear
     if self.max_message_length > 1:
         self.embedding = layers.Embedding(vocab_size,
                                           embed_dim,
                                           name='embedding')
     self.__build()
Example #23
y_train = to_categorical(y_train)
y_test = to_categorical(y_test)
"""
# Shape
print(X_train.shape)
print(X_test.shape)
print(y_train.shape)
print(y_test.shape)

###### Model definition ######
from tensorflow.keras import layers, Sequential, Input, Model

I = Input(shape=(66, ))

E = layers.Embedding(vocab_size, vocab_size)(I)
B1 = layers.RNN(layers.GRUCell(64), return_sequences=True)(E)
DR1 = layers.Dropout(0.2)(B1)
D1 = layers.Dense(128, activation="relu")(DR1)
DR2 = layers.Dropout(0.2)(D1)
D2 = layers.Dense(64, activation="relu")(DR2)
DR3 = layers.Dropout(0.2)(D2)
O = layers.Dense(18, activation="softmax")(DR3)

model = Model(I, O)

model.summary()

model.compile(optimizer="adam",
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])
"""
 def build(self, input_shape):
     self.atom_dim = input_shape[0][-1]
     self.message_step = EdgeNetwork()
     self.pad_length = max(0, self.units - self.atom_dim)
     self.update_step = layers.GRUCell(self.atom_dim + self.pad_length)
     self.built = True
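The pad_length trick sizes the GRUCell to max(units, atom_dim): when the atom features are narrower than units, they are zero-padded so they can act directly as the cell state. A sketch of the matching step in call(), with hypothetical tensor names:

     atom_features = tf.pad(atom_features, [(0, 0), (0, self.pad_length)])
     # one message-passing step, then a GRU update of the node states
     aggregated = self.message_step([atom_features, bond_features, pair_indices])
     atom_features, _ = self.update_step(aggregated, atom_features)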
Example #25
grads, _ = tf.clip_by_global_norm(grads, 25)  # global gradient clipping
# update the parameters with the clipped gradients
optimizer.apply_gradients(zip(grads, model.trainable_variables))

#%%
x = tf.random.normal([2, 80, 100])
xt = x[:, 0, :]  # take the input at the first timestep
cell = layers.LSTMCell(64)  # create the cell
# initialize the state list [h, c]
state = [tf.zeros([2, 64]), tf.zeros([2, 64])]
out, state = cell(xt, state)  # forward pass
id(out), id(state[0]), id(state[1])

#%%
net = layers.LSTM(4)
net.build(input_shape=(None, 5, 3))
net.trainable_variables
#%%

net = layers.GRU(4)
net.build(input_shape=(None, 5, 3))
net.trainable_variables

#%%
# initialize the state vector
h = [tf.zeros([2, 64])]
cell = layers.GRUCell(64)  # create a GRU cell
for xt in tf.unstack(x, axis=1):
    out, h = cell(xt, h)
out.shape
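The manual unstack loop above is exactly what layers.RNN automates when handed the same cell. A minimal equivalent sketch:

rnn = layers.RNN(layers.GRUCell(64))  # consumes [b, T, d] and returns the last output
out = rnn(x)                          # x: [2, 80, 100] -> out: [2, 64]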