Exemplo n.º 1
0
    def test_get_hidden_vector_summary(self):
        """Summary of hidden vectors [1]*4, [2]*4, [3]*4 is [2, 2, 2, 2].

        NOTE(review): the expected value equals the element-wise mean of the
        three pushed vectors; confirm that mean is the intended summary.
        """
        stack = state_stack.create(max_size=10, hidden_vector_size=4)
        stack = state_stack.push(stack, test_utils.constant(1),
                                 test_utils.constant([1, 1, 1, 1]))
        stack = state_stack.push(stack, test_utils.constant(2),
                                 test_utils.constant([2, 2, 2, 2]))
        stack = state_stack.push(stack, test_utils.constant(3),
                                 test_utils.constant([3, 3, 3, 3]))

        hidden_vector_summary = state_stack.get_hidden_vector_summary(stack)

        with tf.Session() as sess:
            # tf.shape evaluates to a 1-D array such as [4]. Compare it as a
            # list instead of against the scalar 4. Comparing an ndarray to an
            # int relies on the truth value of an element-wise comparison,
            # which is ambiguous for anything but a single-element shape.
            self.assertListEqual(
                sess.run(tf.shape(hidden_vector_summary)).tolist(), [4])
            self.assertListEqual(
                sess.run(hidden_vector_summary).tolist(), [2, 2, 2, 2])
Exemplo n.º 2
0
    def test_is_empty_after_pop(self):
        """Stack reports non-empty until every pushed entry has been popped."""
        stack = state_stack.create(max_size=10, hidden_vector_size=4)
        for value in (1, 2):
            stack = state_stack.push(stack, test_utils.constant(value),
                                     test_utils.constant([value] * 4))

        # Emptiness flags after 0, 1 and 2 pops, in that order.
        emptiness_after_pops = [state_stack.is_empty(stack)]
        for _ in range(2):
            _, _, stack = state_stack.pop(stack)
            emptiness_after_pops.append(state_stack.is_empty(stack))

        with tf.Session() as sess:
            self.assertTrue(sess.run(emptiness_after_pops[2]))
            for flag in emptiness_after_pops[:2]:
                self.assertFalse(sess.run(flag))
Exemplo n.º 3
0
    def test_push_pop(self):
        """A single pushed (state, hidden vector) pair is returned by pop."""
        expected_state = 1
        expected_hidden = [1, 1, 1, 1]

        stack = state_stack.create(max_size=10, hidden_vector_size=4)
        stack = state_stack.push(stack,
                                 test_utils.constant(expected_state),
                                 test_utils.constant(expected_hidden))
        popped_state, popped_hidden, _ = state_stack.pop(stack)

        with tf.Session() as sess:
            self.assertEqual(sess.run(popped_state), expected_state)
            self.assertListEqual(sess.run(popped_hidden).tolist(),
                                 expected_hidden)
Exemplo n.º 4
0
    def test_peek(self):
        """Popping the same stack value twice yields the same top entry.

        Both pops are applied to the same `stack` and their returned stacks
        are discarded. The test therefore expects pop not to alter its input
        stack, which gives peek-like behaviour.
        """
        stack = state_stack.create(max_size=10, hidden_vector_size=4)
        stack = state_stack.push(stack, test_utils.constant(1),
                                 test_utils.constant([1, 1, 1, 1]))

        state1, hidden_vector1, _ = state_stack.pop(stack)
        state2, hidden_vector2, _ = state_stack.pop(stack)

        with tf.Session() as sess:
            # Pin both results to the pushed values. Checking only that the
            # two pops agree would let a pop that consistently returns wrong
            # data pass.
            self.assertEqual(sess.run(state1), 1)
            self.assertEqual(sess.run(state2), 1)
            self.assertListEqual(
                sess.run(hidden_vector1).tolist(), [1, 1, 1, 1])
            self.assertListEqual(
                sess.run(hidden_vector2).tolist(), [1, 1, 1, 1])
Exemplo n.º 5
0
    def test_multiple_push_pop(self):
        """Entries pop in LIFO order, and the stack is empty afterwards."""
        stack = state_stack.create(max_size=10, hidden_vector_size=4)
        stack = state_stack.push(stack, test_utils.constant(1),
                                 test_utils.constant([1, 1, 1, 1]))
        stack = state_stack.push(stack, test_utils.constant(2),
                                 test_utils.constant([2, 2, 2, 2]))
        stack = state_stack.push(stack, test_utils.constant(3),
                                 test_utils.constant([3, 3, 3, 3]))

        state3, hidden_vector3, stack = state_stack.pop(stack)
        state2, hidden_vector2, stack = state_stack.pop(stack)
        state1, hidden_vector1, stack = state_stack.pop(stack)
        # The stack left after the last pop was previously unused. All three
        # pushes have now been popped, so it should report empty.
        is_empty_at_end = state_stack.is_empty(stack)

        with tf.Session() as sess:
            self.assertEqual(sess.run(state1), 1)
            self.assertListEqual(
                sess.run(hidden_vector1).tolist(), [1, 1, 1, 1])
            self.assertEqual(sess.run(state2), 2)
            self.assertListEqual(
                sess.run(hidden_vector2).tolist(), [2, 2, 2, 2])
            self.assertEqual(sess.run(state3), 3)
            self.assertListEqual(
                sess.run(hidden_vector3).tolist(), [3, 3, 3, 3])
            self.assertTrue(sess.run(is_empty_at_end))