def testBasicRNNCell(self):
  """Runs a slim-wrapped `basic_rnn_cell` once and checks the output shape.

  The cell is built with `num_units=2` inside the "root" variable scope,
  whose default initializer is the constant 0.5. Input and state are fed as
  1x2 arrays and the fetched output is expected to have shape (1, 2).
  """
  with self.test_session() as sess, tf.variable_scope(
      "root", initializer=tf.constant_initializer(0.5)):
    input_tensor = tf.zeros([1, 2])
    state_tensor = tf.zeros([1, 2])
    cell_fn = functools.partial(basic_rnn_cell, num_units=2)
    # pylint: disable=protected-access
    wrapped_cell = rnn_cell_impl._SlimRNNCell(cell_fn)
    # pylint: enable=protected-access
    cell_output, _ = wrapped_cell(input_tensor, state_tensor)
    sess.run([tf.global_variables_initializer()])
    # Override the zero-valued graph tensors by name with concrete values.
    feed = {
        input_tensor.name: np.array([[1., 1.]]),
        state_tensor.name: np.array([[0.1, 0.1]]),
    }
    (output_value,) = sess.run([cell_output], feed)
    self.assertEqual(output_value.shape, (1, 2))
def testBasicRNNCell(self):
  """Checks the output shape of `basic_rnn_cell` wrapped in `_SlimRNNCell`.

  Builds the wrapped cell (`num_units=2`) in the "root" variable scope with a
  constant 0.5 default initializer, feeds 1x2 values for the input and the
  state, and asserts that the fetched output has shape (1, 2).
  """
  with self.test_session() as sess:
    half_init = init_ops.constant_initializer(0.5)
    root_scope = variable_scope.variable_scope("root", initializer=half_init)
    with root_scope:
      input_tensor = array_ops.zeros([1, 2])
      state_tensor = array_ops.zeros([1, 2])
      cell_fn = functools.partial(basic_rnn_cell, num_units=2)
      # pylint: disable=protected-access
      cell_output, _ = rnn_cell_impl._SlimRNNCell(cell_fn)(
          input_tensor, state_tensor)
      # pylint: enable=protected-access
      sess.run([variables_lib.global_variables_initializer()])
      # Feed concrete values in place of the zero tensors, keyed by name.
      feed = {
          input_tensor.name: np.array([[1., 1.]]),
          state_tensor.name: np.array([[0.1, 0.1]]),
      }
      fetched = sess.run([cell_output], feed)
      output_value = fetched[0]
      self.assertEqual(output_value.shape, (1, 2))
def testBasicRNNCellMatch(self):
  """Compares a slim-wrapped `basic_rnn_cell` against `BasicRNNCell`.

  Both cells receive the same random input and the same initial state inside
  the "root" variable scope (constant 0.5 default initializer). The test
  asserts that their outputs and states agree in static shape and, after
  evaluation, in value.
  """
  batch_size = 32
  input_size = 100
  num_units = 10
  with self.test_session() as sess, tf.variable_scope(
      "root", initializer=tf.constant_initializer(0.5)):
    inputs = tf.random_uniform((batch_size, input_size))
    _, initial_state = basic_rnn_cell(inputs, None, num_units)

    # Slim-wrapped cell is built first.
    cell_fn = functools.partial(basic_rnn_cell, num_units=num_units)
    # pylint: disable=protected-access
    wrapped_cell = rnn_cell_impl._SlimRNNCell(cell_fn)
    # pylint: enable=protected-access
    slim_outputs, slim_state = wrapped_cell(inputs, initial_state)

    # Reference cell is called only after the scope is switched to reuse.
    reference_cell = tf.nn.rnn_cell.BasicRNNCell(num_units)
    tf.get_variable_scope().reuse_variables()
    outputs, state = reference_cell(inputs, initial_state)

    self.assertEqual(slim_outputs.get_shape(), outputs.get_shape())
    self.assertEqual(slim_state.get_shape(), state.get_shape())

    sess.run([tf.global_variables_initializer()])
    # A single run so that both cells see the same random input sample.
    slim_out_val, slim_state_val, out_val, state_val = sess.run(
        [slim_outputs, slim_state, outputs, state])
    self.assertAllClose(slim_out_val, out_val)
    self.assertAllClose(slim_state_val, state_val)
def testBasicRNNCellMatch(self):
  """Verifies `_SlimRNNCell(basic_rnn_cell)` agrees with `BasicRNNCell`.

  The reference `BasicRNNCell` is applied first; the variable scope is then
  switched to reuse mode before the slim-wrapped cell is applied to the same
  input and initial state. Static shapes and evaluated values of both the
  outputs and the states must match.
  """
  batch_size = 32
  input_size = 100
  num_units = 10
  with self.test_session() as sess:
    half_init = init_ops.constant_initializer(0.5)
    with variable_scope.variable_scope("root", initializer=half_init):
      inputs = random_ops.random_uniform((batch_size, input_size))
      _, initial_state = basic_rnn_cell(inputs, None, num_units)

      # Reference cell is built before reuse is enabled.
      reference_cell = rnn_cell_impl.BasicRNNCell(num_units)
      outputs, state = reference_cell(inputs, initial_state)
      variable_scope.get_variable_scope().reuse_variables()

      # Slim-wrapped cell is built with the scope in reuse mode.
      cell_fn = functools.partial(basic_rnn_cell, num_units=num_units)
      # pylint: disable=protected-access
      wrapped_cell = rnn_cell_impl._SlimRNNCell(cell_fn)
      # pylint: enable=protected-access
      slim_outputs, slim_state = wrapped_cell(inputs, initial_state)

      self.assertEqual(slim_outputs.get_shape(), outputs.get_shape())
      self.assertEqual(slim_state.get_shape(), state.get_shape())

      sess.run([variables_lib.global_variables_initializer()])
      # A single run so that both cells see the same random input sample.
      fetches = [slim_outputs, slim_state, outputs, state]
      slim_out_val, slim_state_val, out_val, state_val = sess.run(fetches)
      self.assertAllClose(slim_out_val, out_val)
      self.assertAllClose(slim_state_val, state_val)