Example #1 — testAttentionCellWrapperFailures: invalid cell and attn_length arguments raise errors.
  def testAttentionCellWrapperFailures(self):
    with self.assertRaisesRegexp(TypeError,
                                 "The parameter cell is not RNNCell."):
      rnn_cell.AttentionCellWrapper(None, 0)

    num_units = 8
    for state_is_tuple in [False, True]:
      with ops.Graph().as_default():
        lstm_cell = core_rnn_cell_impl.BasicLSTMCell(
            num_units, state_is_tuple=state_is_tuple)
        with self.assertRaisesRegexp(
            ValueError, "attn_length should be greater than zero, got 0"):
          rnn_cell.AttentionCellWrapper(
              lstm_cell, 0, state_is_tuple=state_is_tuple)
        with self.assertRaisesRegexp(
            ValueError, "attn_length should be greater than zero, got -1"):
          rnn_cell.AttentionCellWrapper(
              lstm_cell, -1, state_is_tuple=state_is_tuple)
      with ops.Graph().as_default():
        lstm_cell = core_rnn_cell_impl.BasicLSTMCell(
            num_units, state_is_tuple=True)
        with self.assertRaisesRegexp(
            ValueError, "Cell returns tuple of states, but the flag "
            "state_is_tuple is not set. State size is: *"):
          rnn_cell.AttentionCellWrapper(lstm_cell, 4, state_is_tuple=False)
Example #2 — testAttentionCellWrapperZeros: all-zero inputs and state produce an all-zero output and state.
 def testAttentionCellWrapperZeros(self):
     num_units = 8
     attn_length = 16
     batch_size = 3
     input_size = 4
     for state_is_tuple in [False, True]:
         with ops.Graph().as_default():
             with self.test_session() as sess:
                 with variable_scope.variable_scope("state_is_tuple_" +
                                                    str(state_is_tuple)):
                     lstm_cell = core_rnn_cell_impl.BasicLSTMCell(
                         num_units, state_is_tuple=state_is_tuple)
                     cell = rnn_cell.AttentionCellWrapper(
                         lstm_cell,
                         attn_length,
                         state_is_tuple=state_is_tuple)
                     if state_is_tuple:
                         zeros = array_ops.zeros([batch_size, num_units],
                                                 dtype=np.float32)
                         attn_state_zeros = array_ops.zeros(
                             [batch_size, attn_length * num_units],
                             dtype=np.float32)
                         zero_state = ((zeros, zeros), zeros,
                                       attn_state_zeros)
                     else:
                         zero_state = array_ops.zeros(
                             [batch_size,
                              num_units * 2 + attn_length * num_units + num_units],
                             dtype=np.float32)
                     inputs = array_ops.zeros([batch_size, input_size],
                                              dtype=dtypes.float32)
                     output, state = cell(inputs, zero_state)
                     self.assertEquals(output.get_shape(),
                                       [batch_size, num_units])
                     if state_is_tuple:
                         self.assertEquals(len(state), 3)
                         self.assertEquals(len(state[0]), 2)
                         self.assertEquals(state[0][0].get_shape(),
                                           [batch_size, num_units])
                         self.assertEquals(state[0][1].get_shape(),
                                           [batch_size, num_units])
                         self.assertEquals(state[1].get_shape(),
                                           [batch_size, num_units])
                         self.assertEquals(
                             state[2].get_shape(),
                             [batch_size, attn_length * num_units])
                         tensors = [output] + list(state)
                     else:
                         self.assertEquals(state.get_shape(), [
                             batch_size, num_units * 2 + num_units +
                             attn_length * num_units
                         ])
                         tensors = [output, state]
                     zero_result = sum([
                         math_ops.reduce_sum(math_ops.abs(x))
                         for x in tensors
                     ])
                     sess.run(variables.global_variables_initializer())
                     self.assertTrue(sess.run(zero_result) < 1e-6)
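The test assembles the all-zero initial state by hand to make its layout explicit. When only an initial state is needed, the wrapper can build the same structure itself through the standard RNNCell.zero_state method. A short sketch, reusing the module aliases and sizes from the example above:

# Sketch: let the wrapper build its own zero state instead of assembling it
# by hand; same structure as the tuple constructed in the test above.
lstm_cell = core_rnn_cell_impl.BasicLSTMCell(num_units, state_is_tuple=True)
cell = rnn_cell.AttentionCellWrapper(lstm_cell, attn_length, state_is_tuple=True)
# Returns ((c, h), attention, attention_state) filled with zeros.
zero_state = cell.zero_state(batch_size, dtypes.float32)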
Example #3 — testAttentionCellWrapperCorrectResult: output and state match precomputed reference values.
 def testAttentionCellWrapperCorrectResult(self):
   num_units = 4
   attn_length = 6
   batch_size = 2
   expected_output = np.array(
       [[1.068372, 0.45496, -0.678277, 0.340538],
        [1.018088, 0.378983, -0.572179, 0.268591]],
       dtype=np.float32)
   expected_state = np.array(
       [[0.74946702, 0.34681597, 0.26474735, 1.06485605, 0.38465962,
         0.11420801, 0.10272158, 0.30925757, 0.63899988, 0.7181077,
         0.47534478, 0.33715725, 0.58086717, 0.49446869, 0.7641536,
         0.12814975, 0.92231739, 0.89857256, 0.21889746, 0.38442063,
         0.53481543, 0.8876909, 0.45823169, 0.5905602, 0.78038228,
         0.56501579, 0.03971386, 0.09870267, 0.8074435, 0.66821432,
         0.99211812, 0.12295902, 1.14606023, 0.34370938, -0.79251152,
         0.51843399],
        [0.5179342, 0.48682183, -0.25426468, 0.96810579, 0.28809637,
         0.13607743, -0.11446252, 0.26792109, 0.78047138, 0.63460857,
         0.49122369, 0.52007174, 0.73000264, 0.66986895, 0.73576689,
         0.86301267, 0.87887371, 0.35185754, 0.93417215, 0.64732957,
         0.63173044, 0.66627824, 0.53644657, 0.20477486, 0.98458421,
         0.38277245, 0.03746676, 0.92510188, 0.57714164, 0.84932971,
         0.36127412, 0.12125921, 1.1362772, 0.34361625, -0.78150457,
         0.70582712]],
       dtype=np.float32)
   seed = 12345
   random_seed.set_random_seed(seed)
   for state_is_tuple in [False, True]:
     with session.Session() as sess:
       with variable_scope.variable_scope(
           "state_is_tuple", reuse=state_is_tuple,
           initializer=init_ops.glorot_uniform_initializer()):
         lstm_cell = core_rnn_cell_impl.BasicLSTMCell(
             num_units, state_is_tuple=state_is_tuple)
         cell = rnn_cell.AttentionCellWrapper(
             lstm_cell, attn_length, state_is_tuple=state_is_tuple)
         zeros1 = random_ops.random_uniform(
             (batch_size, num_units), 0.0, 1.0, seed=seed + 1)
         zeros2 = random_ops.random_uniform(
             (batch_size, num_units), 0.0, 1.0, seed=seed + 2)
         zeros3 = random_ops.random_uniform(
             (batch_size, num_units), 0.0, 1.0, seed=seed + 3)
         attn_state_zeros = random_ops.random_uniform(
             (batch_size, attn_length * num_units), 0.0, 1.0, seed=seed + 4)
         zero_state = ((zeros1, zeros2), zeros3, attn_state_zeros)
         if not state_is_tuple:
           zero_state = array_ops.concat([
               zero_state[0][0], zero_state[0][1], zero_state[1], zero_state[2]
           ], 1)
         inputs = random_ops.random_uniform(
             (batch_size, num_units), 0.0, 1.0, seed=seed + 5)
         output, state = cell(inputs, zero_state)
         if state_is_tuple:
           state = array_ops.concat(
               [state[0][0], state[0][1], state[1], state[2]], 1)
         sess.run(variables.global_variables_initializer())
         self.assertAllClose(sess.run(output), expected_output)
         self.assertAllClose(sess.run(state), expected_state)
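With state_is_tuple=False the wrapper carries one flat state tensor; the tests recover it from the tuple form by concatenating [c, h, attention, attention_state] along axis 1. A small sketch of the reverse mapping, reusing num_units, attn_length, and the concatenated state tensor from the example above (the component names are illustrative):

# Sketch: split a flat (state_is_tuple=False) state back into its parts,
# following the column order the test concatenates above.
flat_sizes = [num_units,                # LSTM c state
              num_units,                # LSTM h state
              num_units,                # attention vector
              attn_length * num_units]  # flattened attention window
c, h, attention, attn_state = array_ops.split(state, flat_sizes, axis=1)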
Example #4 — testAttentionCellWrapperValues: different inputs in a batch give different outputs and states (followed by an older variant of testAttentionCellWrapperCorrectResult that uses array_ops.concat_v2).
 def testAttentionCellWrapperValues(self):
     num_units = 8
     attn_length = 16
     batch_size = 3
     for state_is_tuple in [False, True]:
         with ops.Graph().as_default():
             with self.test_session() as sess:
                 with variable_scope.variable_scope("state_is_tuple_" +
                                                    str(state_is_tuple)):
                     lstm_cell = core_rnn_cell_impl.BasicLSTMCell(
                         num_units, state_is_tuple=state_is_tuple)
                     cell = rnn_cell.AttentionCellWrapper(
                         lstm_cell,
                         attn_length,
                         state_is_tuple=state_is_tuple)
                     if state_is_tuple:
                         zeros = constant_op.constant(
                             0.1 * np.ones([batch_size, num_units], dtype=np.float32),
                             dtype=dtypes.float32)
                         attn_state_zeros = constant_op.constant(
                             0.1 * np.ones([batch_size, attn_length * num_units],
                                           dtype=np.float32),
                             dtype=dtypes.float32)
                         zero_state = ((zeros, zeros), zeros,
                                       attn_state_zeros)
                     else:
                         zero_state = constant_op.constant(
                             0.1 * np.ones(
                                 [batch_size,
                                  num_units * 2 + num_units + attn_length * num_units],
                                 dtype=np.float32),
                             dtype=dtypes.float32)
                     inputs = constant_op.constant(
                         np.array([[1., 1., 1., 1.],
                                   [2., 2., 2., 2.],
                                   [3., 3., 3., 3.]],
                                  dtype=np.float32),
                         dtype=dtypes.float32)
                     output, state = cell(inputs, zero_state)
                     if state_is_tuple:
                         concat_state = array_ops.concat(
                             [state[0][0], state[0][1], state[1], state[2]],
                             1)
                     else:
                         concat_state = state
                     sess.run(variables.global_variables_initializer())
                     output, state = sess.run([output, concat_state])
                     # Different inputs so different outputs and states
                     for i in range(1, batch_size):
                         self.assertTrue(
                             float(
                                 np.linalg.norm((output[0, :] -
                                                 output[i, :]))) > 1e-6)
                         self.assertTrue(
                             float(
                                 np.linalg.norm((state[0, :] -
                                                 state[i, :]))) > 1e-6)
 def testAttentionCellWrapperCorrectResult(self):
     num_units = 4
     attn_length = 6
     batch_size = 2
     expected_output = np.array([[0.955392, 0.408507, -0.60122, 0.270718],
                                 [0.903681, 0.331165, -0.500238, 0.224052]],
                                dtype=np.float32)
     expected_state = np.array(
         [[
             0.81331915, 0.32036272, 0.28079176, 1.08888793, 0.41264394,
             0.1062041, 0.10444493, 0.32050529, 0.64655536, 0.70794445,
             0.51896095, 0.31809306, 0.58086717, 0.49446869, 0.7641536,
             0.12814975, 0.92231739, 0.89857256, 0.21889746, 0.38442063,
             0.53481543, 0.8876909, 0.45823169, 0.5905602, 0.78038228,
             0.56501579, 0.03971386, 0.09870267, 0.8074435, 0.66821432,
             0.99211812, 0.12295902, 1.01412082, 0.33123279, -0.71114945,
             0.40583119
         ],
          [
              0.59962207, 0.42597458, -0.22491696, 0.98063421, 0.32548007,
              0.11623692, -0.10100613, 0.27708149, 0.76956916, 0.6360054,
              0.51719815, 0.50458527, 0.73000264, 0.66986895, 0.73576689,
              0.86301267, 0.87887371, 0.35185754, 0.93417215, 0.64732957,
              0.63173044, 0.66627824, 0.53644657, 0.20477486, 0.98458421,
              0.38277245, 0.03746676, 0.92510188, 0.57714164, 0.84932971,
              0.36127412, 0.12125921, 0.99780077, 0.31886846, -0.67595094,
              0.56531656
          ]],
         dtype=np.float32)
     seed = 12345
     random_seed.set_random_seed(seed)
     for state_is_tuple in [False, True]:
         with session.Session() as sess:
             with variable_scope.variable_scope("state_is_tuple",
                                                reuse=state_is_tuple):
                 lstm_cell = core_rnn_cell_impl.BasicLSTMCell(
                     num_units, state_is_tuple=state_is_tuple)
                 cell = rnn_cell.AttentionCellWrapper(
                     lstm_cell, attn_length, state_is_tuple=state_is_tuple)
                 zeros1 = random_ops.random_uniform((batch_size, num_units),
                                                    0.0,
                                                    1.0,
                                                    seed=seed + 1)
                 zeros2 = random_ops.random_uniform((batch_size, num_units),
                                                    0.0,
                                                    1.0,
                                                    seed=seed + 2)
                 zeros3 = random_ops.random_uniform((batch_size, num_units),
                                                    0.0,
                                                    1.0,
                                                    seed=seed + 3)
                 attn_state_zeros = random_ops.random_uniform(
                     (batch_size, attn_length * num_units),
                     0.0,
                     1.0,
                     seed=seed + 4)
                 zero_state = ((zeros1, zeros2), zeros3, attn_state_zeros)
                 if not state_is_tuple:
                     zero_state = array_ops.concat_v2([
                         zero_state[0][0], zero_state[0][1], zero_state[1],
                         zero_state[2]
                     ], 1)
                 inputs = random_ops.random_uniform((batch_size, num_units),
                                                    0.0,
                                                    1.0,
                                                    seed=seed + 5)
                 output, state = cell(inputs, zero_state)
                 if state_is_tuple:
                     state = array_ops.concat_v2(
                         [state[0][0], state[0][1], state[1], state[2]], 1)
                 sess.run(variables.global_variables_initializer())
                 self.assertAllClose(sess.run(output), expected_output)
                 self.assertAllClose(sess.run(state), expected_state)
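The tests call the wrapped cell for a single step; since AttentionCellWrapper is an ordinary RNNCell, it can also be unrolled over a whole sequence. A minimal sketch, assuming a TF 1.x installation where the wrapper is exposed as tf.contrib.rnn.AttentionCellWrapper (shapes and names below are illustrative, not taken from the tests):

import tensorflow as tf

# Illustrative shapes: variable batch size, 10 time steps, 4 input features.
sequence = tf.placeholder(tf.float32, [None, 10, 4])
lstm_cell = tf.contrib.rnn.BasicLSTMCell(8, state_is_tuple=True)
attn_cell = tf.contrib.rnn.AttentionCellWrapper(lstm_cell, attn_length=6,
                                                state_is_tuple=True)
# dynamic_rnn unrolls the attention-wrapped cell over the time dimension.
outputs, final_state = tf.nn.dynamic_rnn(attn_cell, sequence, dtype=tf.float32)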