def rnn_cell(rnn_cell_size, dropout_keep_prob, residual, is_training=True):
  """Builds a stack of (optionally residual) LSTMBlockCells.

  Args:
    rnn_cell_size: Sequence of integer unit counts, one per stacked layer.
    dropout_keep_prob: Probability of keeping each layer's inputs, applied
      through rnn.DropoutWrapper(input_keep_prob=...). Ignored (treated as
      1.0) when `is_training` is False.
    residual: If True, every layer is wrapped in rnn.ResidualWrapper. The
      first layer, and any layer whose size differs from the one before it,
      is additionally wrapped in contrib_rnn.InputProjectionWrapper.
    is_training: Whether dropout should be active.

  Returns:
    An rnn.MultiRNNCell over the wrapped layers.
  """
  keep_prob = dropout_keep_prob if is_training else 1.0
  layers = []
  previous_size = None
  for layer_index, size in enumerate(rnn_cell_size):
    layer = contrib_rnn.LSTMBlockCell(size)
    if residual:
      layer = rnn.ResidualWrapper(layer)
      # Project the incoming activations to `size` units so they match the
      # width of the LSTM output that ResidualWrapper adds them to.
      if layer_index == 0 or size != previous_size:
        layer = contrib_rnn.InputProjectionWrapper(layer, size)
    layers.append(rnn.DropoutWrapper(layer, input_keep_prob=keep_prob))
    previous_size = size
  return rnn.MultiRNNCell(layers)
def testInputProjectionWrapper(self):
  """Smoke-tests InputProjectionWrapper around a 3-unit GRUCell."""
  # These values were not derived independently; they only guard against
  # unexpected changes in the wrapper's output.
  expected_output = [[0.154605, 0.154605, 0.154605]]
  with self.cached_session() as session:
    with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
      inputs = tf.zeros([1, 2])
      state = tf.zeros([1, 3])
      projected_cell = contrib_rnn.InputProjectionWrapper(
          rnn_cell.GRUCell(3), num_proj=3)
      output, new_state = projected_cell(inputs, state)
      session.run([tf.global_variables_initializer()])
      feed = {
          inputs.name: np.array([[1., 1.]]),
          state.name: np.array([[0.1, 0.1, 0.1]]),
      }
      output_value, new_state_value = session.run([output, new_state], feed)
    self.assertEqual(new_state_value.shape, (1, 3))
    self.assertAllClose(output_value, expected_output)
def make_rnn_cell(rnn_layer_sizes,
                  dropout_keep_prob=1.0,
                  attn_length=0,
                  base_cell=rnn.BasicLSTMCell,
                  residual_connections=False):
  """Builds a stacked RNN cell from the given hyperparameters.

  Args:
    rnn_layer_sizes: List of integer unit counts, one entry per RNN layer.
    dropout_keep_prob: Float probability of keeping each sub-cell's output,
      applied through rnn.DropoutWrapper(output_keep_prob=...).
    attn_length: Attention length. When truthy, only the first layer is
      wrapped in contrib_rnn.AttentionCellWrapper.
    base_cell: Callable (normally an rnn.RNNCell subclass) invoked with a
      layer size to build each sub-cell.
    residual_connections: Whether to wrap each sub-cell in
      rnn.ResidualWrapper. The first layer, and any layer whose size differs
      from its predecessor, also gets a contrib_rnn.InputProjectionWrapper.

  Returns:
    An rnn.MultiRNNCell stacking the wrapped sub-cells.
  """
  wrapped_layers = []
  for depth, num_units in enumerate(rnn_layer_sizes):
    is_first_layer = depth == 0
    layer = base_cell(num_units)
    if attn_length and is_first_layer:
      # Attention is only attached to the bottom of the stack.
      layer = contrib_rnn.AttentionCellWrapper(
          layer, attn_length, state_is_tuple=True)
    if residual_connections:
      layer = rnn.ResidualWrapper(layer)
      width_changes = (is_first_layer or
                       num_units != rnn_layer_sizes[depth - 1])
      if width_changes:
        # Project inputs to `num_units` so the residual sum is shape-compatible.
        layer = contrib_rnn.InputProjectionWrapper(layer, num_units)
    wrapped_layers.append(
        rnn.DropoutWrapper(layer, output_keep_prob=dropout_keep_prob))

  stacked = rnn.MultiRNNCell(wrapped_layers)
  return stacked