Example #1
    def testAttentionCellWrapperFailures(self):
        with self.assertRaisesRegexp(
                TypeError, contrib_rnn.ASSERT_LIKE_RNNCELL_ERROR_REGEXP):
            contrib_rnn.AttentionCellWrapper(None, 0)

        num_units = 8
        for state_is_tuple in [False, True]:
            with tf.Graph().as_default():
                lstm_cell = rnn_cell.BasicLSTMCell(
                    num_units, state_is_tuple=state_is_tuple)
                with self.assertRaisesRegexp(
                        ValueError,
                        "attn_length should be greater than zero, got 0"):
                    contrib_rnn.AttentionCellWrapper(
                        lstm_cell, 0, state_is_tuple=state_is_tuple)
                with self.assertRaisesRegexp(
                        ValueError,
                        "attn_length should be greater than zero, got -1"):
                    contrib_rnn.AttentionCellWrapper(
                        lstm_cell, -1, state_is_tuple=state_is_tuple)
            with tf.Graph().as_default():
                lstm_cell = rnn_cell.BasicLSTMCell(num_units,
                                                   state_is_tuple=True)
                with self.assertRaisesRegexp(
                        ValueError,
                        "Cell returns tuple of states, but the flag "
                        "state_is_tuple is not set. State size is: *"):
                    contrib_rnn.AttentionCellWrapper(lstm_cell,
                                                     4,
                                                     state_is_tuple=False)
Example #2
 def testAttentionCellWrapperZeros(self):
     num_units = 8
     attn_length = 16
     batch_size = 3
     input_size = 4
     for state_is_tuple in [False, True]:
         with tf.Graph().as_default():
             with self.cached_session() as sess:
                 with tf.variable_scope("state_is_tuple_" +
                                        str(state_is_tuple)):
                     lstm_cell = rnn_cell.BasicLSTMCell(
                         num_units, state_is_tuple=state_is_tuple)
                     cell = contrib_rnn.AttentionCellWrapper(
                         lstm_cell,
                         attn_length,
                         state_is_tuple=state_is_tuple)
                     if state_is_tuple:
                         zeros = tf.zeros([batch_size, num_units],
                                          dtype=np.float32)
                         attn_state_zeros = tf.zeros(
                             [batch_size, attn_length * num_units],
                             dtype=np.float32)
                         zero_state = ((zeros, zeros), zeros,
                                       attn_state_zeros)
                     else:
                         zero_state = tf.zeros([
                             batch_size, num_units * 2 +
                             attn_length * num_units + num_units
                         ],
                                               dtype=np.float32)
                     inputs = tf.zeros([batch_size, input_size],
                                       dtype=tf.float32)
                     output, state = cell(inputs, zero_state)
                     self.assertEqual(output.get_shape(),
                                      [batch_size, num_units])
                     if state_is_tuple:
                         self.assertEqual(len(state), 3)
                         self.assertEqual(len(state[0]), 2)
                         self.assertEqual(state[0][0].get_shape(),
                                          [batch_size, num_units])
                         self.assertEqual(state[0][1].get_shape(),
                                          [batch_size, num_units])
                         self.assertEqual(state[1].get_shape(),
                                          [batch_size, num_units])
                         self.assertEqual(
                             state[2].get_shape(),
                             [batch_size, attn_length * num_units])
                         tensors = [output] + list(state)
                     else:
                         self.assertEqual(state.get_shape(), [
                             batch_size, num_units * 2 + num_units +
                             attn_length * num_units
                         ])
                         tensors = [output, state]
                     zero_result = sum(
                         [tf.reduce_sum(tf.abs(x)) for x in tensors])
                     sess.run(tf.global_variables_initializer())
                     self.assertLess(sess.run(zero_result), 1e-6)
Example #3
 def testAttentionCellWrapperValues(self):
     num_units = 8
     attn_length = 16
     batch_size = 3
     for state_is_tuple in [False, True]:
         with tf.Graph().as_default():
             with self.cached_session() as sess:
                 with tf.variable_scope("state_is_tuple_" +
                                        str(state_is_tuple)):
                     lstm_cell = rnn_cell.BasicLSTMCell(
                         num_units, state_is_tuple=state_is_tuple)
                     cell = contrib_rnn.AttentionCellWrapper(
                         lstm_cell,
                         attn_length,
                         state_is_tuple=state_is_tuple)
                     if state_is_tuple:
                         zeros = tf.constant(0.1 * np.ones(
                             [batch_size, num_units], dtype=np.float32),
                                             dtype=tf.float32)
                         attn_state_zeros = tf.constant(
                             0.1 *
                             np.ones([batch_size, attn_length * num_units],
                                     dtype=np.float32),
                             dtype=tf.float32)
                         zero_state = ((zeros, zeros), zeros,
                                       attn_state_zeros)
                     else:
                         zero_state = tf.constant(
                             0.1 * np.ones([
                                 batch_size, num_units * 2 + num_units +
                                 attn_length * num_units
                             ],
                                           dtype=np.float32),
                             dtype=tf.float32)
                     inputs = tf.constant(np.array(
                         [[1., 1., 1., 1.], [2., 2., 2., 2.],
                          [3., 3., 3., 3.]],
                         dtype=np.float32),
                                          dtype=tf.float32)
                     output, state = cell(inputs, zero_state)
                     if state_is_tuple:
                         concat_state = tf.concat(
                             [state[0][0], state[0][1], state[1], state[2]],
                             1)
                     else:
                         concat_state = state
                     sess.run(tf.global_variables_initializer())
                     output, state = sess.run([output, concat_state])
                     # Different inputs so different outputs and states
                     for i in range(1, batch_size):
                         self.assertGreater(
                             float(
                                 np.linalg.norm(
                                     (output[0, :] - output[i, :]))), 1e-6)
                         self.assertGreater(
                             float(
                                 np.linalg.norm(
                                     (state[0, :] - state[i, :]))), 1e-6)
Example #4
def make_rnn_cell(rnn_layer_sizes,
                  dropout_keep_prob=1.0,
                  attn_length=0,
                  base_cell=rnn.BasicLSTMCell,
                  residual_connections=False):
    """Makes a RNN cell from the given hyperparameters.

  Args:
    rnn_layer_sizes: A list of integer sizes (in units) for each layer of the
        RNN.
    dropout_keep_prob: The float probability to keep the output of any given
        sub-cell.
    attn_length: The size of the attention vector.
    base_cell: The base rnn.RNNCell to use for sub-cells.
    residual_connections: Whether or not to use residual connections (via
        rnn.ResidualWrapper).

  Returns:
      A rnn.MultiRNNCell based on the given hyperparameters.
  """
    cells = []
    for i in range(len(rnn_layer_sizes)):
        cell = base_cell(rnn_layer_sizes[i])
        if attn_length and not cells:
            # Add attention wrapper to first layer.
            cell = contrib_rnn.AttentionCellWrapper(cell,
                                                    attn_length,
                                                    state_is_tuple=True)
        if residual_connections:
            cell = rnn.ResidualWrapper(cell)
            if i == 0 or rnn_layer_sizes[i] != rnn_layer_sizes[i - 1]:
                cell = contrib_rnn.InputProjectionWrapper(
                    cell, rnn_layer_sizes[i])
        cell = rnn.DropoutWrapper(cell, output_keep_prob=dropout_keep_prob)
        cells.append(cell)

    cell = rnn.MultiRNNCell(cells)

    return cell
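
A minimal usage sketch for make_rnn_cell, assuming a TF 1.x environment where the `rnn` alias points at `tensorflow.contrib.rnn`; the layer sizes, dropout probability, attention length, and input width below are illustrative values only, not taken from the examples on this page:

import tensorflow as tf  # TF 1.x
from tensorflow.contrib import rnn  # BasicLSTMCell, MultiRNNCell, DropoutWrapper, ...

# Two-layer LSTM stack with an AttentionCellWrapper on the first layer.
cell = make_rnn_cell(rnn_layer_sizes=[64, 64],
                     dropout_keep_prob=0.9,
                     attn_length=16,
                     base_cell=rnn.BasicLSTMCell)

# Unroll the stacked cell over a batch of sequences shaped [batch, time, features].
inputs = tf.placeholder(tf.float32, [None, None, 38])
outputs, final_state = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32)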
Example #5
 def _testAttentionCellWrapperCorrectResult(self):
     num_units = 4
     attn_length = 6
     batch_size = 2
     expected_output = np.array([[1.068372, 0.45496, -0.678277, 0.340538],
                                 [1.018088, 0.378983, -0.572179, 0.268591]],
                                dtype=np.float32)
     expected_state = np.array(
         [[
             0.74946702, 0.34681597, 0.26474735, 1.06485605, 0.38465962,
             0.11420801, 0.10272158, 0.30925757, 0.63899988, 0.7181077,
             0.47534478, 0.33715725, 0.58086717, 0.49446869, 0.7641536,
             0.12814975, 0.92231739, 0.89857256, 0.21889746, 0.38442063,
             0.53481543, 0.8876909, 0.45823169, 0.5905602, 0.78038228,
             0.56501579, 0.03971386, 0.09870267, 0.8074435, 0.66821432,
             0.99211812, 0.12295902, 1.14606023, 0.34370938, -0.79251152,
             0.51843399
         ],
          [
              0.5179342, 0.48682183, -0.25426468, 0.96810579, 0.28809637,
              0.13607743, -0.11446252, 0.26792109, 0.78047138, 0.63460857,
              0.49122369, 0.52007174, 0.73000264, 0.66986895, 0.73576689,
              0.86301267, 0.87887371, 0.35185754, 0.93417215, 0.64732957,
              0.63173044, 0.66627824, 0.53644657, 0.20477486, 0.98458421,
              0.38277245, 0.03746676, 0.92510188, 0.57714164, 0.84932971,
              0.36127412, 0.12125921, 1.1362772, 0.34361625, -0.78150457,
              0.70582712
          ]],
         dtype=np.float32)
     seed = 12345
     tf.set_random_seed(seed)
     rnn_scope = None
     for state_is_tuple in [False, True]:
         with tf.Session() as sess:
             with tf.variable_scope(
                     "state_is_tuple",
                     reuse=state_is_tuple,
                     initializer=tf.glorot_uniform_initializer()):
                 lstm_cell = rnn_cell.BasicLSTMCell(
                     num_units, state_is_tuple=state_is_tuple)
                 cell = contrib_rnn.AttentionCellWrapper(
                     lstm_cell, attn_length, state_is_tuple=state_is_tuple)
                 # This is legacy behavior to preserve the test.  Weight
                 # sharing no longer works by creating a new RNNCell in the
                 # same variable scope; so here we restore the scope of the
                 # RNNCells after the first use below.
                 if rnn_scope is not None:
                     (cell._scope, lstm_cell._scope) = rnn_scope  # pylint: disable=protected-access,unpacking-non-sequence
                 zeros1 = tf.random_uniform((batch_size, num_units),
                                            0.0,
                                            1.0,
                                            seed=seed + 1)
                 zeros2 = tf.random_uniform((batch_size, num_units),
                                            0.0,
                                            1.0,
                                            seed=seed + 2)
                 zeros3 = tf.random_uniform((batch_size, num_units),
                                            0.0,
                                            1.0,
                                            seed=seed + 3)
                 attn_state_zeros = tf.random_uniform(
                     (batch_size, attn_length * num_units),
                     0.0,
                     1.0,
                     seed=seed + 4)
                 zero_state = ((zeros1, zeros2), zeros3, attn_state_zeros)
                 if not state_is_tuple:
                     zero_state = tf.concat([
                         zero_state[0][0], zero_state[0][1], zero_state[1],
                         zero_state[2]
                     ], 1)
                 inputs = tf.random_uniform((batch_size, num_units),
                                            0.0,
                                            1.0,
                                            seed=seed + 5)
                 output, state = cell(inputs, zero_state)
                 # This is legacy behavior to preserve the test.  Weight
                 # sharing no longer works by creating a new RNNCell in the
                 # same variable scope; so here we store the scope of the
                 # first RNNCell for reuse above.
                 if rnn_scope is None:
                     rnn_scope = (cell._scope, lstm_cell._scope)  # pylint: disable=protected-access
                 if state_is_tuple:
                     state = tf.concat(
                         [state[0][0], state[0][1], state[1], state[2]], 1)
                 sess.run(tf.global_variables_initializer())
                 self.assertAllClose(sess.run(output), expected_output)
                 self.assertAllClose(sess.run(state), expected_state)