def test_get_hidden_vector_summary(self): stack = state_stack.create(max_size=10, hidden_vector_size=4) stack = state_stack.push(stack, test_utils.constant(1), test_utils.constant([1, 1, 1, 1])) stack = state_stack.push(stack, test_utils.constant(2), test_utils.constant([2, 2, 2, 2])) stack = state_stack.push(stack, test_utils.constant(3), test_utils.constant([3, 3, 3, 3])) hidden_vector_summary = state_stack.get_hidden_vector_summary(stack) with tf.Session() as sess: self.assertEqual(sess.run(tf.shape(hidden_vector_summary)), 4) self.assertListEqual( sess.run(hidden_vector_summary).tolist(), [2, 2, 2, 2])
def test_is_empty_after_pop(self): stack = state_stack.create(max_size=10, hidden_vector_size=4) stack = state_stack.push(stack, test_utils.constant(1), test_utils.constant([1, 1, 1, 1])) stack = state_stack.push(stack, test_utils.constant(2), test_utils.constant([2, 2, 2, 2])) is_empty_false1 = state_stack.is_empty(stack) _, _, stack = state_stack.pop(stack) is_empty_false2 = state_stack.is_empty(stack) _, _, stack = state_stack.pop(stack) is_empty_true = state_stack.is_empty(stack) with tf.Session() as sess: self.assertTrue(sess.run(is_empty_true)) self.assertFalse(sess.run(is_empty_false1)) self.assertFalse(sess.run(is_empty_false2))
def test_push_pop(self): stack = state_stack.create(max_size=10, hidden_vector_size=4) stack = state_stack.push(stack, test_utils.constant(1), test_utils.constant([1, 1, 1, 1])) state, hidden_vector, stack = state_stack.pop(stack) with tf.Session() as sess: self.assertEqual(sess.run(state), 1) self.assertListEqual( sess.run(hidden_vector).tolist(), [1, 1, 1, 1])
def test_peek(self): stack = state_stack.create(max_size=10, hidden_vector_size=4) stack = state_stack.push(stack, test_utils.constant(1), test_utils.constant([1, 1, 1, 1])) state1, hidden_vector1, _ = state_stack.pop(stack) state2, hidden_vector2, _ = state_stack.pop(stack) with tf.Session() as sess: self.assertEqual(sess.run(state1), sess.run(state2)) self.assertListEqual( sess.run(hidden_vector1).tolist(), sess.run(hidden_vector2).tolist())
def test_multiple_push_pop(self): stack = state_stack.create(max_size=10, hidden_vector_size=4) stack = state_stack.push(stack, test_utils.constant(1), test_utils.constant([1, 1, 1, 1])) stack = state_stack.push(stack, test_utils.constant(2), test_utils.constant([2, 2, 2, 2])) stack = state_stack.push(stack, test_utils.constant(3), test_utils.constant([3, 3, 3, 3])) state3, hidden_vector3, stack = state_stack.pop(stack) state2, hidden_vector2, stack = state_stack.pop(stack) state1, hidden_vector1, stack = state_stack.pop(stack) with tf.Session() as sess: self.assertEqual(sess.run(state1), 1) self.assertListEqual( sess.run(hidden_vector1).tolist(), [1, 1, 1, 1]) self.assertEqual(sess.run(state2), 2) self.assertListEqual( sess.run(hidden_vector2).tolist(), [2, 2, 2, 2]) self.assertEqual(sess.run(state3), 3) self.assertListEqual( sess.run(hidden_vector3).tolist(), [3, 3, 3, 3])