예제 #1
0
 def trainable_initial_state(self, batch_size):
     """
     Build the per-layer initial states of the MultiSkipGRUCell.

     Every layer except the last contributes only a hidden state. The last
     layer also gets an update probability (ones-initialized) and a
     cumulative update probability (zeros-initialized), both created with
     trainable=False, and its three values are bundled in a
     SkipGRUStateTuple.

     :param batch_size: number of samples per batch
     :return: list with one entry per layer; the leading entries are the
         hidden states returned by rnn_ops.create_initial_state and the
         final entry is a SkipGRUStateTuple
     """
     states = []

     # Layers 1 .. num_layers - 1 carry a hidden state only.
     for layer_no in range(1, self._num_layers):
         with tf.variable_scope('layer_%d' % layer_no):
             with tf.variable_scope('initial_h'):
                 hidden = rnn_ops.create_initial_state(
                     batch_size, self._num_units[layer_no - 1])
         states.append(hidden)

     # The last layer also tracks the skip-related probabilities.
     with tf.variable_scope('layer_%d' % self._num_layers):
         with tf.variable_scope('initial_h'):
             top_hidden = rnn_ops.create_initial_state(
                 batch_size, self._num_units[-1])
         with tf.variable_scope('initial_update_prob'):
             update_prob = rnn_ops.create_initial_state(
                 batch_size, 1,
                 trainable=False, initializer=tf.ones_initializer())
         with tf.variable_scope('initial_cum_update_prob'):
             cum_update_prob = rnn_ops.create_initial_state(
                 batch_size, 1,
                 trainable=False, initializer=tf.zeros_initializer())

     states.append(
         SkipGRUStateTuple(top_hidden, update_prob, cum_update_prob))
     return states
예제 #2
0
 def trainable_initial_state(self, batch_size):
     """
     Create a trainable initial state for the SkipLSTMCell.

     Four values are created, each under a variable scope whose name is
     suffixed with ``self.layer``: the cell state and hidden state (both
     sized by ``self._num_units``), plus the update probability
     (ones-initialized) and the cumulative update probability
     (zeros-initialized), which are created with trainable=False.

     :param batch_size: number of samples per batch
     :return: SkipLSTMStateTuple of (c, h, update_prob, cum_update_prob)
     """
     # Function-scope import: the file's import block is not visible here.
     import logging

     # Trace the call at DEBUG level instead of printing to stdout, so
     # building the graph does not pollute the output of callers.
     logging.getLogger(__name__).debug(
         "SkipLSTMCell_trainable_initial_state_called %s", batch_size)

     with tf.variable_scope(f'initial_c_{self.layer}'):
         initial_c = rnn_ops.create_initial_state(batch_size,
                                                  self._num_units)
     with tf.variable_scope(f'initial_h_{self.layer}'):
         initial_h = rnn_ops.create_initial_state(batch_size,
                                                  self._num_units)
     with tf.variable_scope(f'initial_update_prob_{self.layer}'):
         initial_update_prob = rnn_ops.create_initial_state(
             batch_size,
             1,
             trainable=False,
             initializer=tf.ones_initializer())
     with tf.variable_scope(f'initial_cum_update_prob_{self.layer}'):
         initial_cum_update_prob = rnn_ops.create_initial_state(
             batch_size,
             1,
             trainable=False,
             initializer=tf.zeros_initializer())
     return SkipLSTMStateTuple(initial_c, initial_h, initial_update_prob,
                               initial_cum_update_prob)
예제 #3
0
 def trainable_initial_state(self, batch_size):
     """
     Build the initial hidden state of the BasicGRUCell.

     The state is created under the 'basic_initial_h' variable scope and
     is sized by ``self._num_units``.

     :param batch_size: number of samples per batch
     :return: the value produced by rnn_ops.create_initial_state
         (presumably shaped [batch_size, self.state_size] -- confirm
         against rnn_ops)
     """
     with tf.variable_scope('basic_initial_h'):
         return rnn_ops.create_initial_state(batch_size, self._num_units)
예제 #4
0
 def trainable_initial_state(self, batch_size):
     """
     Build the initial state of the SkipGRUCell.

     The hidden state is sized by ``self._num_units``. The update
     probability starts at ones and the cumulative update probability at
     zeros; both are created with trainable=False.

     :param batch_size: number of samples per batch
     :return: SkipGRUStateTuple of (h, update_prob, cum_update_prob)
     """
     with tf.variable_scope('initial_h'):
         hidden = rnn_ops.create_initial_state(batch_size, self._num_units)

     with tf.variable_scope('initial_update_prob'):
         update_prob = rnn_ops.create_initial_state(
             batch_size, 1,
             trainable=False, initializer=tf.ones_initializer())

     with tf.variable_scope('initial_cum_update_prob'):
         cum_update_prob = rnn_ops.create_initial_state(
             batch_size, 1,
             trainable=False, initializer=tf.zeros_initializer())

     state = SkipGRUStateTuple(hidden, update_prob, cum_update_prob)
     return state