def testBasicRNNCell(self):
   """Checks that _SlimRNNCell around basic_rnn_cell emits a (1, 2) output.

   Variables created under the "root" scope default to a constant 0.5
   initializer. The zero tensors built for the input and previous state are
   overridden through the feed dict at run time. Only the output's shape is
   asserted, not its values.
   """
   with self.test_session() as sess:
     half_init = tf.constant_initializer(0.5)
     with tf.variable_scope("root", initializer=half_init):
       inputs = tf.zeros([1, 2])
       prev_state = tf.zeros([1, 2])
       cell_fn = functools.partial(basic_rnn_cell, num_units=2)
       # pylint: disable=protected-access
       wrapped_cell = rnn_cell_impl._SlimRNNCell(cell_fn)
       # pylint: enable=protected-access
       output, _ = wrapped_cell(inputs, prev_state)
       sess.run([tf.global_variables_initializer()])
       feeds = {
           inputs.name: np.array([[1., 1.]]),
           prev_state.name: np.array([[0.1, 0.1]]),
       }
       (output_value,) = sess.run([output], feeds)
       self.assertEqual(output_value.shape, (1, 2))
Example #2
0
 def testBasicRNNCell(self):
   """Checks that _SlimRNNCell around basic_rnn_cell emits a (1, 2) output.

   Variables created under the "root" scope default to a constant 0.5
   initializer. The zero tensors built for the input and previous state are
   overridden through the feed dict at run time. Only the output's shape is
   asserted, not its values.
   """
   with self.test_session() as sess:
     half_init = init_ops.constant_initializer(0.5)
     with variable_scope.variable_scope("root", initializer=half_init):
       inputs = array_ops.zeros([1, 2])
       prev_state = array_ops.zeros([1, 2])
       cell_fn = functools.partial(basic_rnn_cell, num_units=2)
       # pylint: disable=protected-access
       wrapped_cell = rnn_cell_impl._SlimRNNCell(cell_fn)
       # pylint: enable=protected-access
       output, _ = wrapped_cell(inputs, prev_state)
       sess.run([variables_lib.global_variables_initializer()])
       feeds = {
           inputs.name: np.array([[1., 1.]]),
           prev_state.name: np.array([[0.1, 0.1]]),
       }
       (output_value,) = sess.run([output], feeds)
       self.assertEqual(output_value.shape, (1, 2))
 def testBasicRNNCellMatch(self):
   """Compares _SlimRNNCell(basic_rnn_cell) with tf.nn.rnn_cell.BasicRNNCell.

   The slim-wrapped cell is called first. The "root" scope is then switched
   to reuse mode before the BasicRNNCell is called, presumably so that the
   reference cell picks up the variables the slim cell created (this relies
   on both cells using matching variable names). Output and new state must
   agree both in static shape and in value.
   """
   batch_size, input_size, num_units = 32, 100, 10
   with self.test_session() as sess:
     with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
       inputs = tf.random_uniform((batch_size, input_size))
       # Only the second return value is kept; it seeds both cells.
       _, start_state = basic_rnn_cell(inputs, None, num_units)
       cell_fn = functools.partial(basic_rnn_cell, num_units=num_units)
       # pylint: disable=protected-access
       slim_cell = rnn_cell_impl._SlimRNNCell(cell_fn)
       # pylint: enable=protected-access
       slim_out, slim_new_state = slim_cell(inputs, start_state)
       reference_cell = tf.nn.rnn_cell.BasicRNNCell(num_units)
       # Order matters: reuse is switched on after the slim cell has run and
       # before the reference cell is called.
       tf.get_variable_scope().reuse_variables()
       ref_out, ref_new_state = reference_cell(inputs, start_state)
       pairs = ((slim_out, ref_out), (slim_new_state, ref_new_state))
       for slim_tensor, ref_tensor in pairs:
         self.assertEqual(slim_tensor.get_shape(), ref_tensor.get_shape())
       sess.run([tf.global_variables_initializer()])
       (slim_out_val, slim_state_val,
        ref_out_val, ref_state_val) = sess.run(
            [slim_out, slim_new_state, ref_out, ref_new_state])
       self.assertAllClose(slim_out_val, ref_out_val)
       self.assertAllClose(slim_state_val, ref_state_val)
Example #4
0
 def testBasicRNNCellMatch(self):
   """Compares _SlimRNNCell(basic_rnn_cell) with rnn_cell_impl.BasicRNNCell.

   Here the BasicRNNCell is called first. The "root" scope is then switched
   to reuse mode before the slim-wrapped cell is called, presumably so that
   the slim cell picks up the variables the reference cell created (this
   relies on both cells using matching variable names). Output and new state
   must agree both in static shape and in value.
   """
   batch_size, input_size, num_units = 32, 100, 10
   with self.test_session() as sess:
     with variable_scope.variable_scope(
         "root", initializer=init_ops.constant_initializer(0.5)):
       inputs = random_ops.random_uniform((batch_size, input_size))
       # Only the second return value is kept; it seeds both cells.
       _, start_state = basic_rnn_cell(inputs, None, num_units)
       reference_cell = rnn_cell_impl.BasicRNNCell(num_units)
       ref_out, ref_new_state = reference_cell(inputs, start_state)
       # Order matters: reuse is switched on after the reference cell has
       # run and before the slim cell is called.
       variable_scope.get_variable_scope().reuse_variables()
       cell_fn = functools.partial(basic_rnn_cell, num_units=num_units)
       # pylint: disable=protected-access
       slim_cell = rnn_cell_impl._SlimRNNCell(cell_fn)
       # pylint: enable=protected-access
       slim_out, slim_new_state = slim_cell(inputs, start_state)
       pairs = ((slim_out, ref_out), (slim_new_state, ref_new_state))
       for slim_tensor, ref_tensor in pairs:
         self.assertEqual(slim_tensor.get_shape(), ref_tensor.get_shape())
       sess.run([variables_lib.global_variables_initializer()])
       (slim_out_val, slim_state_val,
        ref_out_val, ref_state_val) = sess.run(
            [slim_out, slim_new_state, ref_out, ref_new_state])
       self.assertAllClose(slim_out_val, ref_out_val)
       self.assertAllClose(slim_state_val, ref_state_val)