def trainable_initial_state(self, batch_size):
    """
    Build the initial state for every layer of the MultiSkipGRUCell.

    Layers 1 .. num_layers-1 only receive an initial hidden state. The last
    layer additionally receives the update-probability entries and is packed
    into a SkipGRUStateTuple.

    :param batch_size: number of samples per batch
    :return: list holding one initial hidden-state tensor per non-final layer,
        followed by a SkipGRUStateTuple for the final layer
    """
    states = []

    # Non-final layers: hidden state only, under scopes layer_1 ... layer_{n-1}.
    for layer_idx in range(self._num_layers - 1):
        layer_scope = 'layer_%d' % (layer_idx + 1)
        with tf.variable_scope(layer_scope), tf.variable_scope('initial_h'):
            hidden = rnn_ops.create_initial_state(
                batch_size, self._num_units[layer_idx])
        states.append(hidden)

    # Final layer: hidden state plus the skip bookkeeping entries.
    with tf.variable_scope('layer_%d' % self._num_layers):
        with tf.variable_scope('initial_h'):
            top_hidden = rnn_ops.create_initial_state(
                batch_size, self._num_units[-1])
        # Both probability entries are created with trainable=False.
        # update_prob uses a ones initializer; cum_update_prob uses zeros.
        with tf.variable_scope('initial_update_prob'):
            top_update_prob = rnn_ops.create_initial_state(
                batch_size, 1, trainable=False,
                initializer=tf.ones_initializer())
        with tf.variable_scope('initial_cum_update_prob'):
            top_cum_update_prob = rnn_ops.create_initial_state(
                batch_size, 1, trainable=False,
                initializer=tf.zeros_initializer())
        states.append(
            SkipGRUStateTuple(top_hidden, top_update_prob, top_cum_update_prob))

    return states
def trainable_initial_state(self, batch_size):
    """
    Create a trainable initial state for the SkipLSTMCell.

    The cell state and hidden state are created with
    rnn_ops.create_initial_state's default arguments. The update-probability
    entries are created with trainable=False: update_prob uses a ones
    initializer and cum_update_prob uses a zeros initializer.

    Every variable scope name is suffixed with ``self.layer``, presumably so
    that several SkipLSTMCell instances can coexist in one graph. These
    names determine the created variable names, so they must stay stable for
    existing checkpoints to restore.

    :param batch_size: number of samples per batch
    :return: SkipLSTMStateTuple(c, h, update_prob, cum_update_prob)
    """
    with tf.variable_scope(f'initial_c_{self.layer}'):
        initial_c = rnn_ops.create_initial_state(batch_size, self._num_units)
    with tf.variable_scope(f'initial_h_{self.layer}'):
        initial_h = rnn_ops.create_initial_state(batch_size, self._num_units)
    with tf.variable_scope(f'initial_update_prob_{self.layer}'):
        initial_update_prob = rnn_ops.create_initial_state(
            batch_size, 1, trainable=False, initializer=tf.ones_initializer())
    with tf.variable_scope(f'initial_cum_update_prob_{self.layer}'):
        initial_cum_update_prob = rnn_ops.create_initial_state(
            batch_size, 1, trainable=False, initializer=tf.zeros_initializer())
    return SkipLSTMStateTuple(initial_c, initial_h, initial_update_prob,
                              initial_cum_update_prob)
def trainable_initial_state(self, batch_size):
    """
    Return the initial hidden state of the BasicGRUCell.

    The state is created under the 'basic_initial_h' variable scope via
    rnn_ops.create_initial_state with ``self._num_units`` units.

    :param batch_size: number of samples per batch
    :return: tensor with shape [batch_size, self.state_size]
    """
    with tf.variable_scope('basic_initial_h'):
        return rnn_ops.create_initial_state(batch_size, self._num_units)
def trainable_initial_state(self, batch_size):
    """
    Build the initial SkipGRUStateTuple for a single SkipGRUCell.

    The hidden state uses rnn_ops.create_initial_state's defaults. The two
    probability entries are created with trainable=False:
    initial_update_prob uses a ones initializer and initial_cum_update_prob
    uses a zeros initializer.

    :param batch_size: number of samples per batch
    :return: SkipGRUStateTuple(h, update_prob, cum_update_prob)
    """
    with tf.variable_scope('initial_h'):
        hidden = rnn_ops.create_initial_state(batch_size, self._num_units)

    # (scope name, initializer factory), in creation order. The factory is
    # only called inside its scope, matching the original construction order.
    prob_specs = (
        ('initial_update_prob', tf.ones_initializer),
        ('initial_cum_update_prob', tf.zeros_initializer),
    )
    prob_states = []
    for scope_name, make_initializer in prob_specs:
        with tf.variable_scope(scope_name):
            prob_states.append(rnn_ops.create_initial_state(
                batch_size, 1, trainable=False,
                initializer=make_initializer()))

    update_prob, cum_update_prob = prob_states
    return SkipGRUStateTuple(hidden, update_prob, cum_update_prob)