# Assumed imports for this section (TF 1.x era); the exact import paths in
# the original project may differ. `contrib_rnn` is whichever module
# provides AttentionCellWrapper, InputProjectionWrapper, and
# ASSERT_LIKE_RNNCELL_ERROR_REGEXP (tf.contrib.rnn or a vendored copy of it).
import numpy as np
import tensorflow as tf

from tensorflow.contrib import rnn as contrib_rnn

rnn = tf.nn.rnn_cell       # Alias used by make_rnn_cell below.
rnn_cell = tf.nn.rnn_cell  # Alias used by the tests below.


def testAttentionCellWrapperFailures(self):
  with self.assertRaisesRegexp(
      TypeError, contrib_rnn.ASSERT_LIKE_RNNCELL_ERROR_REGEXP):
    contrib_rnn.AttentionCellWrapper(None, 0)

  num_units = 8
  for state_is_tuple in [False, True]:
    with tf.Graph().as_default():
      lstm_cell = rnn_cell.BasicLSTMCell(
          num_units, state_is_tuple=state_is_tuple)
      with self.assertRaisesRegexp(
          ValueError, "attn_length should be greater than zero, got 0"):
        contrib_rnn.AttentionCellWrapper(
            lstm_cell, 0, state_is_tuple=state_is_tuple)
      with self.assertRaisesRegexp(
          ValueError, "attn_length should be greater than zero, got -1"):
        contrib_rnn.AttentionCellWrapper(
            lstm_cell, -1, state_is_tuple=state_is_tuple)
  with tf.Graph().as_default():
    lstm_cell = rnn_cell.BasicLSTMCell(num_units, state_is_tuple=True)
    with self.assertRaisesRegexp(
        ValueError, "Cell returns tuple of states, but the flag "
        "state_is_tuple is not set. State size is: *"):
      contrib_rnn.AttentionCellWrapper(lstm_cell, 4, state_is_tuple=False)
def testAttentionCellWrapperZeros(self):
  num_units = 8
  attn_length = 16
  batch_size = 3
  input_size = 4
  for state_is_tuple in [False, True]:
    with tf.Graph().as_default():
      with self.cached_session() as sess:
        with tf.variable_scope("state_is_tuple_" + str(state_is_tuple)):
          lstm_cell = rnn_cell.BasicLSTMCell(
              num_units, state_is_tuple=state_is_tuple)
          cell = contrib_rnn.AttentionCellWrapper(
              lstm_cell, attn_length, state_is_tuple=state_is_tuple)
          if state_is_tuple:
            zeros = tf.zeros([batch_size, num_units], dtype=tf.float32)
            attn_state_zeros = tf.zeros(
                [batch_size, attn_length * num_units], dtype=tf.float32)
            zero_state = ((zeros, zeros), zeros, attn_state_zeros)
          else:
            zero_state = tf.zeros(
                [batch_size,
                 num_units * 2 + attn_length * num_units + num_units],
                dtype=tf.float32)
          inputs = tf.zeros([batch_size, input_size], dtype=tf.float32)
          output, state = cell(inputs, zero_state)
          self.assertEqual(output.get_shape(), [batch_size, num_units])
          if state_is_tuple:
            self.assertEqual(len(state), 3)
            self.assertEqual(len(state[0]), 2)
            self.assertEqual(state[0][0].get_shape(),
                             [batch_size, num_units])
            self.assertEqual(state[0][1].get_shape(),
                             [batch_size, num_units])
            self.assertEqual(state[1].get_shape(), [batch_size, num_units])
            self.assertEqual(state[2].get_shape(),
                             [batch_size, attn_length * num_units])
            tensors = [output] + list(state)
          else:
            self.assertEqual(
                state.get_shape(),
                [batch_size,
                 num_units * 2 + num_units + attn_length * num_units])
            tensors = [output, state]
          zero_result = sum([tf.reduce_sum(tf.abs(x)) for x in tensors])
          sess.run(tf.global_variables_initializer())
          self.assertLess(sess.run(zero_result), 1e-6)
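# A minimal sketch (added for clarity, not in the original test file): the
# hand-built zero states above match what the wrapper's own zero_state()
# method returns in the tuple case. The test name below is hypothetical.
def testAttentionCellWrapperZeroStateStructure(self):
  num_units = 8
  attn_length = 16
  batch_size = 3
  with tf.Graph().as_default():
    lstm_cell = rnn_cell.BasicLSTMCell(num_units, state_is_tuple=True)
    cell = contrib_rnn.AttentionCellWrapper(
        lstm_cell, attn_length, state_is_tuple=True)
    # zero_state returns ((c, h), attention, attention_state).
    (c, h), attention, attention_state = cell.zero_state(
        batch_size, tf.float32)
    self.assertEqual(c.get_shape(), [batch_size, num_units])
    self.assertEqual(h.get_shape(), [batch_size, num_units])
    self.assertEqual(attention.get_shape(), [batch_size, num_units])
    self.assertEqual(attention_state.get_shape(),
                     [batch_size, attn_length * num_units])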
def testAttentionCellWrapperValues(self):
  num_units = 8
  attn_length = 16
  batch_size = 3
  for state_is_tuple in [False, True]:
    with tf.Graph().as_default():
      with self.cached_session() as sess:
        with tf.variable_scope("state_is_tuple_" + str(state_is_tuple)):
          lstm_cell = rnn_cell.BasicLSTMCell(
              num_units, state_is_tuple=state_is_tuple)
          cell = contrib_rnn.AttentionCellWrapper(
              lstm_cell, attn_length, state_is_tuple=state_is_tuple)
          if state_is_tuple:
            zeros = tf.constant(
                0.1 * np.ones([batch_size, num_units], dtype=np.float32),
                dtype=tf.float32)
            attn_state_zeros = tf.constant(
                0.1 * np.ones([batch_size, attn_length * num_units],
                              dtype=np.float32),
                dtype=tf.float32)
            zero_state = ((zeros, zeros), zeros, attn_state_zeros)
          else:
            zero_state = tf.constant(
                0.1 * np.ones(
                    [batch_size,
                     num_units * 2 + num_units + attn_length * num_units],
                    dtype=np.float32),
                dtype=tf.float32)
          inputs = tf.constant(
              np.array(
                  [[1., 1., 1., 1.], [2., 2., 2., 2.], [3., 3., 3., 3.]],
                  dtype=np.float32),
              dtype=tf.float32)
          output, state = cell(inputs, zero_state)
          if state_is_tuple:
            concat_state = tf.concat(
                [state[0][0], state[0][1], state[1], state[2]], 1)
          else:
            concat_state = state
          sess.run(tf.global_variables_initializer())
          output, state = sess.run([output, concat_state])
          # Different inputs should produce different outputs and states.
          for i in range(1, batch_size):
            self.assertGreater(
                float(np.linalg.norm((output[0, :] - output[i, :]))), 1e-6)
            self.assertGreater(
                float(np.linalg.norm((state[0, :] - state[i, :]))), 1e-6)
def make_rnn_cell(rnn_layer_sizes,
                  dropout_keep_prob=1.0,
                  attn_length=0,
                  base_cell=rnn.BasicLSTMCell,
                  residual_connections=False):
  """Makes an RNN cell from the given hyperparameters.

  Args:
    rnn_layer_sizes: A list of integer sizes (in units) for each layer of
        the RNN.
    dropout_keep_prob: The float probability of keeping the output of any
        given sub-cell.
    attn_length: The size of the attention vector.
    base_cell: The base rnn.RNNCell to use for sub-cells.
    residual_connections: Whether or not to use residual connections (via
        rnn.ResidualWrapper).

  Returns:
    An rnn.MultiRNNCell based on the given hyperparameters.
  """
  cells = []
  for i in range(len(rnn_layer_sizes)):
    cell = base_cell(rnn_layer_sizes[i])
    if attn_length and not cells:
      # Add attention wrapper to the first layer only.
      cell = contrib_rnn.AttentionCellWrapper(
          cell, attn_length, state_is_tuple=True)
    if residual_connections:
      cell = rnn.ResidualWrapper(cell)
      # Project the input so its size matches the layer, since the
      # residual connection adds the two elementwise.
      if i == 0 or rnn_layer_sizes[i] != rnn_layer_sizes[i - 1]:
        cell = contrib_rnn.InputProjectionWrapper(cell, rnn_layer_sizes[i])
    cell = rnn.DropoutWrapper(cell, output_keep_prob=dropout_keep_prob)
    cells.append(cell)

  cell = rnn.MultiRNNCell(cells)

  return cell
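# A minimal usage sketch for make_rnn_cell (added for illustration, not part
# of the original module). The layer sizes, keep probability, attention
# length, and the function name are all arbitrary assumptions.
def _example_make_rnn_cell_usage(inputs):
  """Builds a 2-layer attention LSTM stack and unrolls it over `inputs`.

  `inputs` is assumed to be a [batch, time, depth] float32 Tensor.
  """
  cell = make_rnn_cell(
      rnn_layer_sizes=[64, 64],
      dropout_keep_prob=0.9,
      attn_length=40)
  # dynamic_rnn unrolls the stacked cell over the time dimension.
  outputs, final_state = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32)
  return outputs, final_state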
def _testAttentionCellWrapperCorrectResult(self):
  num_units = 4
  attn_length = 6
  batch_size = 2
  expected_output = np.array(
      [[1.068372, 0.45496, -0.678277, 0.340538],
       [1.018088, 0.378983, -0.572179, 0.268591]],
      dtype=np.float32)
  expected_state = np.array(
      [[
          0.74946702, 0.34681597, 0.26474735, 1.06485605, 0.38465962,
          0.11420801, 0.10272158, 0.30925757, 0.63899988, 0.7181077,
          0.47534478, 0.33715725, 0.58086717, 0.49446869, 0.7641536,
          0.12814975, 0.92231739, 0.89857256, 0.21889746, 0.38442063,
          0.53481543, 0.8876909, 0.45823169, 0.5905602, 0.78038228,
          0.56501579, 0.03971386, 0.09870267, 0.8074435, 0.66821432,
          0.99211812, 0.12295902, 1.14606023, 0.34370938, -0.79251152,
          0.51843399
      ],
       [
           0.5179342, 0.48682183, -0.25426468, 0.96810579, 0.28809637,
           0.13607743, -0.11446252, 0.26792109, 0.78047138, 0.63460857,
           0.49122369, 0.52007174, 0.73000264, 0.66986895, 0.73576689,
           0.86301267, 0.87887371, 0.35185754, 0.93417215, 0.64732957,
           0.63173044, 0.66627824, 0.53644657, 0.20477486, 0.98458421,
           0.38277245, 0.03746676, 0.92510188, 0.57714164, 0.84932971,
           0.36127412, 0.12125921, 1.1362772, 0.34361625, -0.78150457,
           0.70582712
       ]],
      dtype=np.float32)
  seed = 12345
  tf.set_random_seed(seed)
  rnn_scope = None
  for state_is_tuple in [False, True]:
    with tf.Session() as sess:
      with tf.variable_scope(
          "state_is_tuple",
          reuse=state_is_tuple,
          initializer=tf.glorot_uniform_initializer()):
        lstm_cell = rnn_cell.BasicLSTMCell(
            num_units, state_is_tuple=state_is_tuple)
        cell = contrib_rnn.AttentionCellWrapper(
            lstm_cell, attn_length, state_is_tuple=state_is_tuple)
        # This is legacy behavior to preserve the test. Weight
        # sharing no longer works by creating a new RNNCell in the
        # same variable scope; so here we restore the scope of the
        # RNNCells after the first use below.
        if rnn_scope is not None:
          (cell._scope, lstm_cell._scope) = rnn_scope  # pylint: disable=protected-access,unpacking-non-sequence
        zeros1 = tf.random_uniform(
            (batch_size, num_units), 0.0, 1.0, seed=seed + 1)
        zeros2 = tf.random_uniform(
            (batch_size, num_units), 0.0, 1.0, seed=seed + 2)
        zeros3 = tf.random_uniform(
            (batch_size, num_units), 0.0, 1.0, seed=seed + 3)
        attn_state_zeros = tf.random_uniform(
            (batch_size, attn_length * num_units), 0.0, 1.0, seed=seed + 4)
        zero_state = ((zeros1, zeros2), zeros3, attn_state_zeros)
        if not state_is_tuple:
          zero_state = tf.concat([
              zero_state[0][0], zero_state[0][1], zero_state[1],
              zero_state[2]
          ], 1)
        inputs = tf.random_uniform(
            (batch_size, num_units), 0.0, 1.0, seed=seed + 5)
        output, state = cell(inputs, zero_state)
        # This is legacy behavior to preserve the test. Weight
        # sharing no longer works by creating a new RNNCell in the
        # same variable scope; so here we store the scope of the
        # first RNNCell for reuse above.
        if rnn_scope is None:
          rnn_scope = (cell._scope, lstm_cell._scope)  # pylint: disable=protected-access
        if state_is_tuple:
          state = tf.concat(
              [state[0][0], state[0][1], state[1], state[2]], 1)
        sess.run(tf.global_variables_initializer())
        self.assertAllClose(sess.run(output), expected_output)
        self.assertAllClose(sess.run(state), expected_state)
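# A minimal sketch (added for context, not part of the original test) of why
# the "legacy behavior" comments above are needed: in later TF 1.x releases,
# variables are owned by the cell object itself, so weights are shared by
# calling the same cell instance again rather than by rebuilding a cell
# inside the same variable scope. The method name is hypothetical.
def _sharedWeightsSketch(self, cell, inputs1, inputs2, state0):
  output1, state1 = cell(inputs1, state0)  # First call creates variables.
  output2, state2 = cell(inputs2, state1)  # Second call reuses them.
  return (output1, state1), (output2, state2)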