Exemplo n.º 1
0
def rnn_cell(rnn_cell_size, dropout_keep_prob, residual, is_training=True):
    """Stacks one LSTMBlockCell per entry of `rnn_cell_size` into a MultiRNNCell.

    Args:
      rnn_cell_size: Sequence of integer layer sizes, one per stacked layer.
      dropout_keep_prob: Input keep probability for every layer's
        DropoutWrapper; only used when `is_training` is truthy.
      residual: If truthy, each layer is wrapped in rnn.ResidualWrapper, and
        additionally in contrib_rnn.InputProjectionWrapper on the first layer
        and on any layer whose size differs from the previous layer's.
      is_training: When falsy, dropout is disabled (keep probability 1.0).

    Returns:
      An rnn.MultiRNNCell holding the wrapped layers in order.
    """
    keep_prob = 1.0 if not is_training else dropout_keep_prob

    layers = []
    previous_size = None
    for layer_index, layer_size in enumerate(rnn_cell_size):
        layer = contrib_rnn.LSTMBlockCell(layer_size)
        if residual:
            layer = rnn.ResidualWrapper(layer)
            # Project the input to `layer_size` at the bottom of the stack and
            # wherever the width changes (presumably so the residual sum lines
            # up).
            needs_projection = layer_index == 0 or layer_size != previous_size
            if needs_projection:
                layer = contrib_rnn.InputProjectionWrapper(layer, layer_size)
        layers.append(rnn.DropoutWrapper(layer, input_keep_prob=keep_prob))
        previous_size = layer_size

    return rnn.MultiRNNCell(layers)
Exemplo n.º 2
0
 def testInputProjectionWrapper(self):
     """Smoke-tests InputProjectionWrapper around a 3-unit GRUCell.

     Runs a [1, 2] input and a [1, 3] state through the wrapped cell with all
     variables initialised to 0.5, then checks the new state's shape and the
     output values.
     """
     with self.cached_session() as sess:
         weight_init = tf.constant_initializer(0.5)
         with tf.variable_scope("root", initializer=weight_init):
             inputs = tf.zeros([1, 2])
             state = tf.zeros([1, 3])
             wrapped = contrib_rnn.InputProjectionWrapper(
                 rnn_cell.GRUCell(3), num_proj=3)
             output, new_state = wrapped(inputs, state)

             sess.run([tf.global_variables_initializer()])
             feed = {
                 inputs.name: np.array([[1., 1.]]),
                 state.name: np.array([[0.1, 0.1, 0.1]]),
             }
             output_val, new_state_val = sess.run([output, new_state], feed)

             self.assertEqual(new_state_val.shape, (1, 3))
             # The expected numbers are a recorded snapshot, not hand-derived;
             # this is only a smoke test.
             self.assertAllClose(output_val, [[0.154605, 0.154605, 0.154605]])
def make_rnn_cell(rnn_layer_sizes,
                  dropout_keep_prob=1.0,
                  attn_length=0,
                  base_cell=rnn.BasicLSTMCell,
                  residual_connections=False):
    """Makes a stacked RNN cell from the given hyperparameters.

    Each entry of `rnn_layer_sizes` yields one sub-cell, built in this order:

    1. `base_cell(size)`.
    2. If `attn_length` is truthy, the first layer only is wrapped in
       contrib_rnn.AttentionCellWrapper.
    3. If `residual_connections` is set, the layer is wrapped in
       rnn.ResidualWrapper. The first layer, and any layer whose size differs
       from the previous one, is additionally wrapped in
       contrib_rnn.InputProjectionWrapper.
    4. An rnn.DropoutWrapper is applied to the sub-cell's outputs.

    Args:
      rnn_layer_sizes: A list of integer sizes (in units) for each layer of the
          RNN.
      dropout_keep_prob: The float probability to keep the output of any given
          sub-cell.
      attn_length: The size of the attention vector. A falsy value (e.g. the
          default 0) disables attention.
      base_cell: The base rnn.RNNCell to use for sub-cells.
      residual_connections: Whether or not to use residual connections (via
          rnn.ResidualWrapper).

    Returns:
      A rnn.MultiRNNCell based on the given hyperparameters.
    """
    stacked = []
    previous_units = None
    for depth, units in enumerate(rnn_layer_sizes):
        is_bottom_layer = depth == 0
        layer = base_cell(units)

        if attn_length and is_bottom_layer:
            # Attention wraps the bottom layer only.
            layer = contrib_rnn.AttentionCellWrapper(layer,
                                                     attn_length,
                                                     state_is_tuple=True)

        if residual_connections:
            layer = rnn.ResidualWrapper(layer)
            # Project the input to `units` at the bottom of the stack and
            # wherever the width changes (presumably so the residual sum has
            # matching widths).
            if is_bottom_layer or units != previous_units:
                layer = contrib_rnn.InputProjectionWrapper(layer, units)

        stacked.append(
            rnn.DropoutWrapper(layer, output_keep_prob=dropout_keep_prob))
        previous_units = units

    return rnn.MultiRNNCell(stacked)