예제 #1
0
    def test_swap_right_and_left_nodes(self):
        """Exchanging leaf B (right of *) with leaf C (left of +) swaps just those leaves."""
        # Parent 1:  *(A, B)
        times_root = BinaryTreeNode('*')
        times_root.add_left('A')
        times_root.add_right('B')
        first_tree = BinaryTree(times_root)

        # Parent 2:  +(C, D)
        plus_root = BinaryTreeNode('+')
        plus_root.add_left('C')
        plus_root.add_right('D')
        second_tree = BinaryTree(plus_root)

        from_first = times_root.right
        from_second = plus_root.left

        # Expected shape after the exchange:
        #     *         +
        #    / \       / \
        #   A   C  ,  B   D
        SubtreeExchangeRecombinatorBase._swap_subtrees(from_first, from_second, first_tree, second_tree)

        first_root = first_tree.root
        self.check_root(first_root, '*', 'A', 'C')
        self.check_leaf(first_root.left, 'A', '*')
        self.check_leaf(first_root.right, 'C', '*')

        second_root = second_tree.root
        self.check_root(second_root, '+', 'B', 'D')
        self.check_leaf(second_root.left, 'B', '+')
        self.check_leaf(second_root.right, 'D', '+')
예제 #2
0
    def test_swap_stump_and_node(self):
        """Exchanging leaf A with the whole tree rooted at + grafts + under *.

        Afterwards the second tree holds only the former leaf A (a stump).
        """
        # Parent 1:  *(A, B)
        times_root = BinaryTreeNode('*')
        times_root.add_left('A')
        times_root.add_right('B')
        first_tree = BinaryTree(times_root)

        # Parent 2:  +(C, D)
        plus_root = BinaryTreeNode('+')
        plus_root.add_left('C')
        plus_root.add_right('D')
        second_tree = BinaryTree(plus_root)

        # Expected shape after the exchange:
        #     *
        #    / \
        #   +   B   ,  A
        #  / \
        # C   D
        SubtreeExchangeRecombinatorBase._swap_subtrees(times_root.left, plus_root, first_tree, second_tree)

        grafted = first_tree.root
        self.check_root(grafted, '*', '+', 'B')
        self._check_node(grafted.left, '+', 'C', 'D', '*')
        self.check_leaf(grafted.right, 'B', '*')
        self.check_leaf(grafted.left.left, 'C', '+')
        self.check_leaf(grafted.left.right, 'D', '+')

        stump = second_tree.root
        self.check_stump(stump, 'A')
예제 #3
0
    def test_crossover_trees_roots_selected(self):
        """When the selected crossover points are both roots, the results equal the parents."""
        # Parent 1:  *(B, +(D, *(F, G)))
        first_root = BinaryTreeNode('*')
        first_root.add_left('B')
        plus_node = first_root.add_right('+')
        plus_node.add_left('D')
        inner_times = plus_node.add_right('*')
        inner_times.add_left('F')
        inner_times.add_right('G')
        first_tree = BinaryTree(first_root)

        # Parent 2:  +(+(K, L), *(N, M))
        second_root = BinaryTreeNode('+')
        left_plus = second_root.add_left('+')
        right_times = second_root.add_right('*')
        left_plus.add_left('K')
        left_plus.add_right('L')
        right_times.add_right('M')
        right_times.add_left('N')
        second_tree = BinaryTree(second_root)

        # Force node selection to return the pair of roots.
        self.recombinator.select_node_pair = MagicMock(return_value=(first_root, second_root))
        child_1, child_2 = self.recombinator.crossover([first_tree, second_tree])

        self.assertIsInstance(child_1, BinaryTree)
        self.assertIsInstance(child_2, BinaryTree)

        self.recombinator.select_node_pair.assert_called_once()

        self.assertEqual(child_1, first_tree)
        self.assertEqual(child_2, second_tree)
예제 #4
0
    def test_swap_complex_trees(self):
        """Exchanging two internal nodes (B and C) carries their children along."""
        # Parent 1:  *(A, B(-, R))
        times_root = BinaryTreeNode('*')
        times_root.add_left('A')
        b_node = times_root.add_right('B')
        b_node.add_right('R')
        first_tree = BinaryTree(times_root)

        # Parent 2:  +(C(L, -), D)
        plus_root = BinaryTreeNode('+')
        c_node = plus_root.add_left('C')
        plus_root.add_right('D')
        c_node.add_left('L')
        second_tree = BinaryTree(plus_root)

        # Expected shape after exchanging B and C:
        #     *           +
        #    / \         / \
        #   A   C   ,   B   D
        #      /         \
        #     L           R
        SubtreeExchangeRecombinatorBase._swap_subtrees(times_root.right, plus_root.left, first_tree, second_tree)

        first_root = first_tree.root
        self.check_root(first_root, '*', 'A', 'C')
        self.check_leaf(first_root.left, 'A', '*')
        self._check_node(first_root.right, 'C', 'L', None, '*')
        self.check_leaf(first_root.right.left, 'L', 'C')

        second_root = second_tree.root
        self.check_root(second_root, '+', 'B', 'D')
        self._check_node(second_root.left, 'B', None, 'R', '+')
        self.check_leaf(second_root.right, 'D', '+')
        self.check_leaf(second_root.left.right, 'R', 'B')
예제 #5
0
    def test_crossover_roots(self):
        """Crossover with the real (unmocked) node selection returns trees equal to the parents."""
        # Parent 1:  *(B, +(D, *(F, G)))
        first_root = BinaryTreeNode('*')
        first_root.add_left('B')
        plus_node = first_root.add_right('+')
        plus_node.add_left('D')
        inner_times = plus_node.add_right('*')
        inner_times.add_left('F')
        inner_times.add_right('G')
        first_tree = BinaryTree(first_root)

        # Parent 2:  +(+(K, L), J)
        second_root = BinaryTreeNode('+')
        left_plus = second_root.add_left('+')
        second_root.add_right('J')
        left_plus.add_left('K')
        left_plus.add_right('L')
        second_tree = BinaryTree(second_root)

        child_1, child_2 = self.recombinator.crossover([first_tree, second_tree])

        self.assertIsInstance(child_1, BinaryTree)
        self.assertIsInstance(child_2, BinaryTree)

        self.assertEqual(child_1, first_tree)
        self.assertEqual(child_2, second_tree)
예제 #6
0
    def test_postfix_tokens(self):
        """postfix_tokens() must list each operator after both of its operands.

        The comparison is order-sensitive on purpose: a multiset comparison
        (assertCountEqual) would accept any permutation of the tokens and so
        could not detect a wrongly ordered postfix sequence.
        """
        # Tree:  *(+(A, B), +(C, D))  ->  A B + C D + *
        tree = BinaryTree()
        root = BinaryTreeNode('*')
        tree.root = root

        left = root.add_left('+')
        right = root.add_right('+')
        left.add_left('A')
        left.add_right('B')
        right.add_left('C')
        right.add_right('D')

        tokens = ['A', 'B', '+', 'C', 'D', '+', tree.root.label]
        result = tree.postfix_tokens()
        self.assertEqual(list(result), tokens)

        # Tree:  +(+(A, B), +(C, D))  ->  A B + C D + +
        tree = BinaryTree()
        root = BinaryTreeNode('+')
        tree.root = root

        left = root.add_left('+')
        right = root.add_right('+')
        left.add_left('A')
        left.add_right('B')
        right.add_left('C')
        right.add_right('D')

        tokens = ['A', 'B', '+', 'C', 'D', '+', '+']
        result = tree.postfix_tokens()
        self.assertEqual(list(result), tokens)
예제 #7
0
    def test_structural_hamming_dist_complex_trees(self):
        """Distance between trees whose roots' left children have different arities."""
        #    tree 1
        #       *
        #      / \
        #    10   20
        #   /
        # 40
        root_1 = BinaryTreeNode('*')
        left = root_1.add_left(10)
        root_1.add_right(20)
        left.add_left(40)
        tree_1 = BinaryTree(root_1)

        #    tree 2
        #       +
        #      / \
        #    10   20
        #   /  \
        # 50   40
        root_2 = BinaryTreeNode('+')
        left = root_2.add_left(10)
        root_2.add_right(20)
        left.add_right(40)
        left.add_left(50)
        tree_2 = BinaryTree(root_2)

        result = structural_hamming_dist(tree_1, tree_2)
        # The distance is a computed float ratio; compare with a tolerance
        # rather than exact equality so summation order cannot break the test.
        self.assertAlmostEqual(2 / 3, result)
예제 #8
0
    def test_swap_nodes_with_children(self):
        """Swapping the two roots leaves each tree with its own original root."""
        node_1 = BinaryTreeNode('*')
        node_1.add_left('A')
        node_1.add_right('B')
        tree_1 = BinaryTree(node_1)

        node_2 = BinaryTreeNode('+')
        node_2.add_left('C')
        node_2.add_right('D')
        tree_2 = BinaryTree(node_2)

        a = node_1
        b = node_2
        # Both selected nodes are roots, so each tree is expected to keep its
        # own structure (this is what the assertions below check):
        #     *         +
        #    / \       / \
        #   A   B  ,  C   D
        SubtreeExchangeRecombinatorBase._swap_subtrees(a, b, tree_1, tree_2)

        root_1 = tree_1.root
        self.check_root(root_1, '*', 'A', 'B')
        self.check_leaf(root_1.left, 'A', '*')
        self.check_leaf(root_1.right, 'B', '*')

        root_2 = tree_2.root
        self.check_root(root_2, '+', 'C', 'D')
        self.check_leaf(root_2.left, 'C', '+')
        self.check_leaf(root_2.right, 'D', '+')
예제 #9
0
class TestSubTreeExchangeMutator(TestCase):
    """Tests for SubTreeExchangeMutator: validation, mutation and (de)serialization."""

    def setUp(self):
        # Fixture tree:  *(A, B)
        self.tree = BinaryTree()
        self.root = BinaryTreeNode('*')
        self.tree.root = self.root
        self.root.add_left('A')
        self.root.add_right('B')

    def test_max_depth(self):
        self.assertRaises(ValueError, SubTreeExchangeMutator, max_depth=-2, binary_tree_node_cls=BinaryTreeNode)

    def test__mutate_subtree_exchange(self):
        max_depth = 2
        tree_gen = GrowGenerator(max_depth)

        # Measure the height BEFORE mutating. The mutation may modify self.tree
        # in place (test__swap_mut_subtree shows _swap_mut_subtree does), so a
        # height taken afterwards would compare the result with itself and the
        # bound below would hold trivially.
        initial_height = self.tree.height()

        result = SubTreeExchangeMutator._mutate_subtree_exchange(['+', '*'], [1, 2, 3], self.tree, tree_gen)
        self.assertIsInstance(result, BinaryTree)
        # Presumably a tree grown to max_depth spans at most max_depth + 1 levels.
        max_height = max_depth + 1
        final_height = result.height()
        self.assertLessEqual(final_height, initial_height + max_height)

    def test__swap_mut_subtree(self):
        # Replacement tree:  *(C, D)
        random_tree = BinaryTree()
        subtree_root = random_tree.root = BinaryTreeNode('*')
        subtree_left = random_tree.root.add_left('C')
        subtree_right = random_tree.root.add_right('D')

        r = 0  # post-order index 0 of the fixture tree is leaf A
        result = SubTreeExchangeMutator._swap_mut_subtree(self.tree, r, random_tree)
        self.assertIsInstance(result, BinaryTree)
        self.assertEqual(result.height(), 3)
        # Leaf A (left child of the fixture root) has been replaced by *(C, D).
        self.assertEqual(self.tree.root.left, subtree_root)
        self.assertEqual(self.tree.root.left.left, subtree_left)
        self.assertEqual(self.tree.root.left.right, subtree_right)

    def test_to_dict(self):
        mutator = SubTreeExchangeMutator(4, BinaryTreeNode)
        actual = mutator.to_dict()
        self.assertIsInstance(actual, dict)
        self.assertEqual("src.evalg.genprog.mutation", actual["__module__"])
        self.assertEqual("SubTreeExchangeMutator", actual["__class__"])
        self.assertEqual("src.evalg.encoding", actual["binary_tree_node_module_name"])
        self.assertEqual("BinaryTreeNode", actual["binary_tree_node_cls_name"])
        self.assertEqual(mutator.max_depth, actual["max_depth"])

    def test_from_dict(self):
        # Deserializing through the concrete class or any of its bases must
        # yield a SubTreeExchangeMutator.
        test_cases = (SubTreeExchangeMutator, TreeMutator, Serializable)
        for cls in test_cases:
            with self.subTest(name=cls.__name__):
                mutator = SubTreeExchangeMutator(4, BinaryTreeNode)
                actual = cls.from_dict(mutator.to_dict())
                self.assertIsInstance(actual, SubTreeExchangeMutator)
                self.assertEqual(BinaryTreeNode, actual.binary_tree_node_cls)
                self.assertEqual(mutator.max_depth, actual.max_depth)
예제 #10
0
 def test_select_node_pair_one_pair(self):
     """select_node_pair returns None when the common region holds a single pair."""
     first = BinaryTreeNode('*')
     first.add_left('A')
     first.add_right('B')

     second = BinaryTreeNode('+')
     second.add_left('C')
     second.add_right('D')

     region = [(first, second)]
     chosen = self.recombinator.select_node_pair(region)
     self.assertIsNone(chosen)
예제 #11
0
 def test_select_node_pair_same_operator(self):
     """With two candidate pairs (roots share '*'), the chosen pair comes from the region."""
     first = BinaryTreeNode('*')
     first.add_left('A')
     first.add_right('B')

     second = BinaryTreeNode('*')
     second.add_left('C')
     second.add_right('D')

     leaf_pair = (BinaryTreeNode('C'), BinaryTreeNode('D'))
     region = [(first, second), leaf_pair]
     chosen = self.recombinator.select_node_pair(region)
     self.assertIn(chosen, region)
예제 #12
0
    def test_structural_hamming_dist_small_trees(self):
        """Distance between *(10, 20) and +(10, 30): root and right leaf differ."""
        root_1 = BinaryTreeNode('*')
        root_1.add_left(10)
        root_1.add_right(20)
        tree_1 = BinaryTree(root_1)

        root_2 = BinaryTreeNode('+')
        root_2.add_left(10)
        root_2.add_right(30)
        tree_2 = BinaryTree(root_2)

        result = structural_hamming_dist(tree_1, tree_2)
        # Computed float ratio: compare with a tolerance, not exact equality.
        self.assertAlmostEqual(2 / 3, result)
예제 #13
0
    def test_select_node_pair_t_prob_0(self):
        """With t_prob set to 0, the leaf pair (C, D) is selected, not the root pair."""
        self.recombinator.t_prob = 0

        first = BinaryTreeNode('*')
        first.add_left('A')
        first.add_right('B')

        second = BinaryTreeNode('+')
        second.add_left('C')
        second.add_right('D')

        leaf_pair = (BinaryTreeNode('C'), BinaryTreeNode('D'))
        region = [(first, second), leaf_pair]
        chosen = self.recombinator.select_node_pair(region)
        self.assertEqual(chosen, region[1])
예제 #14
0
    def test_iter(self):
        """Iterating the root visits all 8 nodes built here; each yielded item is `in` the root."""
        root = BinaryTreeNode('*')

        def assert_height(expected):
            self.assertEqual(root.height(), expected)

        assert_height(1)

        left = root.add_left(10)
        assert_height(2)
        right = root.add_right(20)
        assert_height(2)

        left_left = left.add_left(40)
        assert_height(3)
        left.add_right(50)
        assert_height(3)
        right.add_left(60)
        assert_height(3)
        right.add_right(70)
        assert_height(3)

        left_left.add_left(80)
        assert_height(4)

        seen = []
        for item in root:
            self.assertIn(item, root)
            seen.append(item)
        self.assertEqual(len(seen), 8)
예제 #15
0
    def test_crossover_leaves(self):
        """Crossing over at leaves D and N swaps just those two leaves between parents."""
        # Parent 1:  *(B, +(D, *(F, G)))
        first_root = BinaryTreeNode('*')
        first_root.add_left('B')
        plus_node = first_root.add_right('+')
        plus_node.add_left('D')
        inner_times = plus_node.add_right('*')
        inner_times.add_left('F')
        inner_times.add_right('G')
        first_tree = BinaryTree(first_root)

        # Parent 2:  +(+(K, L), *(N, M))
        second_root = BinaryTreeNode('+')
        left_plus = second_root.add_left('+')
        right_times = second_root.add_right('*')
        left_plus.add_left('K')
        left_plus.add_right('L')
        right_times.add_right('M')
        right_times.add_left('N')
        second_tree = BinaryTree(second_root)

        # Force node selection to return the leaf pair (D, N).
        chosen_pair = (first_root.right.left, second_root.right.left)
        self.recombinator.select_node_pair = MagicMock(return_value=chosen_pair)
        child_1, child_2 = self.recombinator.crossover([first_tree, second_tree])

        self.assertIsInstance(child_1, BinaryTree)
        self.assertIsInstance(child_2, BinaryTree)

        self.recombinator.select_node_pair.assert_called_once()

        # Child 1:  *(B, +(N, *(F, G)))
        top_1 = child_1.root
        self.check_root(top_1, '*', 'B', '+')
        self.check_leaf(top_1.left, 'B', '*')
        self._check_node(top_1.right, '+', 'N', '*', '*')
        self.check_leaf(top_1.right.left, 'N', '+')
        self._check_node(top_1.right.right, '*', 'F', 'G', '+')
        self.check_leaf(top_1.right.right.left, 'F', '*')
        self.check_leaf(top_1.right.right.right, 'G', '*')

        # Child 2:  +(+(K, L), *(D, M))
        top_2 = child_2.root
        self.check_root(top_2, '+', '+', '*')
        self._check_node(top_2.left, '+', 'K', 'L', '+')
        self._check_node(top_2.right, '*', 'D', 'M', '+')
        self.check_leaf(top_2.left.left, 'K', '+')
        self.check_leaf(top_2.left.right, 'L', '+')
        self.check_leaf(top_2.right.left, 'D', '*')
        self.check_leaf(top_2.right.right, 'M', '*')
예제 #16
0
    def test_get_common_region(self):
        """get_common_region returns the expected aligned node pairs, root pair first."""
        # Tree 1:  *(B, +(D, *(F, G)))
        first_root = BinaryTreeNode('*')
        first_root.add_left('B')
        plus_node = first_root.add_right('+')
        plus_node.add_left('D')
        inner_times = plus_node.add_right('*')
        inner_times.add_left('F')
        inner_times.add_right('G')

        # Tree 2:  +(+(K, L), *(N, M))
        second_root = BinaryTreeNode('+')
        left_plus = second_root.add_left('+')
        right_times = second_root.add_right('*')
        left_plus.add_left('K')
        left_plus.add_right('L')
        right_times.add_right('M')
        right_times.add_left('N')

        expected = [
            (first_root, second_root),
            (first_root.right, second_root.right),
            (first_root.right.left, second_root.right.left),
        ]
        region = self.recombinator.get_common_region(first_root, second_root)
        self.assertListEqual(region, expected)
예제 #17
0
class TestTreePointMutator(TestCase):
    """Tests for TreePointMutator; numpy's global RNG is seeded for repeatable results."""

    def setUp(self):
        # Fixture tree:  *(A, B)
        self.tree = BinaryTree()
        self.root = BinaryTreeNode('*')
        self.tree.root = self.root
        self.root.add_left('A')
        self.root.add_right('B')
        np.random.seed(42)

    def test_mutate(self):
        point_mutator = TreePointMutator()
        mutated = point_mutator.mutate(['+', '*'], ['A', 'B', 'C', 'D'], self.tree)
        # With seed 42 the root operator is expected to be replaced by '+'.
        self.assertEqual(mutated.root.label, '+')
        self.assertIsInstance(mutated, BinaryTree)

    def test_to_dict(self):
        point_mutator = TreePointMutator(BinaryTreeNode)
        serialized = point_mutator.to_dict()
        self.assertIsInstance(serialized, dict)
        expected_entries = (
            ("src.evalg.genprog.mutation", "__module__"),
            ("TreePointMutator", "__class__"),
            ("src.evalg.encoding", "binary_tree_node_module_name"),
            ("BinaryTreeNode", "binary_tree_node_cls_name"),
        )
        for expected_value, key in expected_entries:
            self.assertEqual(expected_value, serialized[key])

    def test_from_dict(self):
        # Deserializing through the concrete class or any of its bases must
        # yield a TreePointMutator.
        for loader in (TreePointMutator, TreeMutator, Serializable):
            with self.subTest(name=loader.__name__):
                point_mutator = TreePointMutator(BinaryTreeNode)
                restored = loader.from_dict(point_mutator.to_dict())
                self.assertIsInstance(restored, TreePointMutator)
                self.assertEqual(BinaryTreeNode, restored.binary_tree_node_cls)

    def tearDown(self):
        # Re-seed from OS entropy so later tests do not inherit the fixed seed.
        np.random.seed()
예제 #18
0
class TestHalfAndHalfMutator(TestCase):
    """Tests for HalfAndHalfMutator: mutation height bound and (de)serialization."""

    def setUp(self):
        # Fixture tree:  *(A, B)
        self.tree = BinaryTree()
        self.root = BinaryTreeNode('*')
        self.tree.root = self.root
        self.root.add_left('A')
        self.root.add_right('B')

    def test_mutate(self):
        individual = self.tree
        operands = ['A', 'B', 'C']
        mutator = HalfAndHalfMutator(max_depth=2)
        # Measure the height BEFORE mutating: mutate() may modify the
        # individual in place, in which case a height taken afterwards would
        # compare the result with itself and make the bound below vacuous.
        initial_height = individual.height()
        result = mutator.mutate(['+', '*'], operands, individual)
        self.assertIsInstance(result, BinaryTree)
        # Presumably a generated subtree spans at most max_depth + 1 levels.
        max_height = mutator.max_depth + 1
        self.assertLessEqual(result.height(), initial_height + max_height)

    def test_to_dict(self):
        mutator = HalfAndHalfMutator(4, BinaryTreeNode)
        actual = mutator.to_dict()
        self.assertIsInstance(actual, dict)
        self.assertEqual("src.evalg.genprog.mutation", actual["__module__"])
        self.assertEqual("HalfAndHalfMutator", actual["__class__"])
        self.assertEqual("src.evalg.encoding", actual["binary_tree_node_module_name"])
        self.assertEqual("BinaryTreeNode", actual["binary_tree_node_cls_name"])
        self.assertEqual(mutator.max_depth, actual["max_depth"])

    def test_from_dict(self):
        # Deserializing through the concrete class or any of its bases must
        # yield a HalfAndHalfMutator.
        test_cases = (HalfAndHalfMutator, SubTreeExchangeMutator, TreeMutator, Serializable)
        for cls in test_cases:
            with self.subTest(name=cls.__name__):
                mutator = HalfAndHalfMutator(4, BinaryTreeNode)
                actual = cls.from_dict(mutator.to_dict())
                self.assertIsInstance(actual, HalfAndHalfMutator)
                self.assertEqual(BinaryTreeNode, actual.binary_tree_node_cls)
                self.assertEqual(mutator.max_depth, actual.max_depth)
예제 #19
0
    def test_contains(self):
        """`in` finds values held by a node or by its descendants."""
        root = BinaryTreeNode('*')
        self.assertIn('*', root)

        left = root.add_left(10)
        self.assertIn('*', root)
        self.assertIn(10, root)
        self.assertIn(10, left)
        self.assertIn(10, root.left)

        right = root.add_right(20)
        self.assertIn('*', root)
        # Mirror the left-child checks: the new value must be visible from the
        # root as well as from the child itself.
        self.assertIn(20, root)
        self.assertIn(20, right)
        self.assertIn(20, root.right)
예제 #20
0
    def test_len(self):
        """len(node) counts the nodes in the subtree rooted at that node."""
        def expect_sizes(*checks):
            for node, size in checks:
                self.assertEqual(len(node), size)

        root = BinaryTreeNode('*')
        expect_sizes((root, 1))

        left = root.add_left(10)
        expect_sizes((root, 2), (left, 1))

        right = root.add_right(20)
        expect_sizes((root, 3), (left, 1), (right, 1))

        left_left = left.add_left(40)
        expect_sizes((root, 4), (left, 2), (right, 1), (left_left, 1))

        left_right = left.add_right(50)
        expect_sizes((root, 5), (left, 3), (right, 1), (left_left, 1), (left_right, 1))

        right_left = right.add_left(60)
        expect_sizes((root, 6), (left, 3), (right, 2), (left_left, 1), (left_right, 1),
                     (right_left, 1))

        right_right = right.add_right(70)
        expect_sizes((root, 7), (left, 3), (right, 3), (left_left, 1), (left_right, 1),
                     (right_left, 1), (right_right, 1))
예제 #21
0
    def test_height(self):
        """height() is 1 for a lone node and grows by one per added level."""
        root = BinaryTreeNode('*')

        def assert_height(expected):
            self.assertEqual(root.height(), expected)

        assert_height(1)

        left = root.add_left(10)
        assert_height(2)
        right = root.add_right(20)
        assert_height(2)

        left_left = left.add_left(40)
        assert_height(3)
        left.add_right(50)
        assert_height(3)
        right.add_left(60)
        assert_height(3)
        right.add_right(70)
        assert_height(3)

        left_left.add_left(80)
        assert_height(4)
예제 #22
0
class TestBinaryTreeNode(TestCase):
    """Tests for BinaryTreeNode: child/parent links, predicates, size, height, iteration."""

    def setUp(self):
        self.root_val = 'Parent Value'
        self.root = BinaryTreeNode(self.root_val)

        self.left_child_val = 42
        self.right_child_val = 13

    def test_has_left_child(self):
        self.assertFalse(self.root.has_left_child())
        # Adding a right child must not count as having a left child.
        self.root.add_right(self.right_child_val)
        self.assertFalse(self.root.has_left_child())
        self.root.add_left(self.left_child_val)
        self.assertTrue(self.root.has_left_child())

    def test_has_right_child(self):
        self.assertFalse(self.root.has_right_child())
        # Adding a left child must not count as having a right child.
        self.root.add_left(self.left_child_val)
        self.assertFalse(self.root.has_right_child())
        self.root.add_right(self.right_child_val)
        self.assertTrue(self.root.has_right_child())

    def test_has_parent(self):
        self.assertFalse(self.root.has_parent())
        right = self.root.add_right(self.right_child_val)
        self.assertTrue(right.has_parent())
        self.assertTrue(self.root.right.has_parent())
        self.assertFalse(self.root.has_parent())
        left = self.root.add_left(self.left_child_val)
        self.assertTrue(left.has_parent())
        self.assertTrue(self.root.left.has_parent())
        self.assertFalse(self.root.has_parent())

    def test_is_left_child(self):
        # A parentless node raises AttributeError rather than returning False.
        self.assertRaises(AttributeError, self.root.is_left_child)
        left = self.root.add_left(self.left_child_val)
        self.assertTrue(left.is_left_child())
        self.assertTrue(self.root.left.is_left_child())
        right = self.root.add_right(self.right_child_val)
        self.assertFalse(right.is_left_child())
        self.assertFalse(self.root.right.is_left_child())

    def test_is_right_child(self):
        # A parentless node raises AttributeError rather than returning False.
        self.assertRaises(AttributeError, self.root.is_right_child)
        left = self.root.add_left(self.left_child_val)
        self.assertFalse(left.is_right_child())
        self.assertFalse(self.root.left.is_right_child())
        right = self.root.add_right(self.right_child_val)
        self.assertTrue(right.is_right_child())
        self.assertTrue(self.root.right.is_right_child())

    def test_is_root(self):
        self.assertTrue(self.root.is_root())
        left = self.root.add_left(self.left_child_val)
        self.assertFalse(left.is_root())
        right = self.root.add_right(self.right_child_val)
        self.assertFalse(right.is_root())

    def test_is_leaf(self):
        self.assertTrue(self.root.is_leaf())
        left = self.root.add_left(self.left_child_val)
        self.assertTrue(left.is_leaf())
        self.assertFalse(self.root.is_leaf())
        right = self.root.add_right(self.right_child_val)
        self.assertTrue(right.is_leaf())
        self.assertFalse(self.root.is_leaf())

    def test_add_left(self):
        result = self.root.add_left(self.left_child_val)
        self.assertEqual(result.parent, self.root)
        self.assertEqual(result.parent.value, self.root_val)
        self.assertEqual(result.parent.left, result)
        self.assertEqual(result.parent.left.value, self.left_child_val)

    def test_add_right(self):
        result = self.root.add_right(self.right_child_val)
        self.assertEqual(result.parent, self.root)
        self.assertEqual(result.parent.value, self.root_val)
        self.assertEqual(result.parent.right, result)
        self.assertEqual(result.parent.right.value, self.right_child_val)

    def test_create_graph(self):
        result = self.root.create_graph()
        self.assertIsInstance(result, Digraph)

    def test_height(self):
        root = BinaryTreeNode('*')
        self.assertEqual(root.height(), 1)

        left = root.add_left(10)
        self.assertEqual(root.height(), 2)
        right = root.add_right(20)
        self.assertEqual(root.height(), 2)

        ll = left.add_left(40)
        self.assertEqual(root.height(), 3)
        left.add_right(50)
        self.assertEqual(root.height(), 3)
        right.add_left(60)
        self.assertEqual(root.height(), 3)
        right.add_right(70)
        self.assertEqual(root.height(), 3)

        ll.add_left(80)
        self.assertEqual(root.height(), 4)

    def test_contains(self):
        root = BinaryTreeNode('*')
        self.assertIn('*', root)

        left = root.add_left(10)
        self.assertIn('*', root)
        self.assertIn(10, root)
        self.assertIn(10, left)
        self.assertIn(10, root.left)

        right = root.add_right(20)
        self.assertIn('*', root)
        # Mirror the left-child checks: the new value must be visible from the
        # root as well as from the child itself.
        self.assertIn(20, root)
        self.assertIn(20, right)
        self.assertIn(20, root.right)

    def test_iter(self):
        root = BinaryTreeNode('*')
        self.assertEqual(root.height(), 1)

        left = root.add_left(10)
        self.assertEqual(root.height(), 2)
        right = root.add_right(20)
        self.assertEqual(root.height(), 2)

        ll = left.add_left(40)
        self.assertEqual(root.height(), 3)
        left.add_right(50)
        self.assertEqual(root.height(), 3)
        right.add_left(60)
        self.assertEqual(root.height(), 3)
        right.add_right(70)
        self.assertEqual(root.height(), 3)

        ll.add_left(80)
        self.assertEqual(root.height(), 4)

        # Iteration must visit all 8 nodes built above.
        result = []
        for value in root:
            self.assertIn(value, root)
            result.append(value)
        self.assertEqual(len(result), 8)

    def test_len(self):
        # len(node) counts the nodes of the subtree rooted at that node.
        root = BinaryTreeNode('*')
        self.assertEqual(len(root), 1)

        left = root.add_left(10)
        self.assertEqual(len(root), 2)
        self.assertEqual(len(left), 1)

        right = root.add_right(20)
        self.assertEqual(len(root), 3)
        self.assertEqual(len(left), 1)
        self.assertEqual(len(right), 1)

        ll = left.add_left(40)
        self.assertEqual(len(root), 4)
        self.assertEqual(len(left), 2)
        self.assertEqual(len(right), 1)
        self.assertEqual(len(ll), 1)

        lr = left.add_right(50)
        self.assertEqual(len(root), 5)
        self.assertEqual(len(left), 3)
        self.assertEqual(len(right), 1)
        self.assertEqual(len(ll), 1)
        self.assertEqual(len(lr), 1)

        rl = right.add_left(60)
        self.assertEqual(len(root), 6)
        self.assertEqual(len(left), 3)
        self.assertEqual(len(right), 2)
        self.assertEqual(len(ll), 1)
        self.assertEqual(len(lr), 1)
        self.assertEqual(len(rl), 1)

        rr = right.add_right(70)
        self.assertEqual(len(root), 7)
        self.assertEqual(len(left), 3)
        self.assertEqual(len(right), 3)
        self.assertEqual(len(ll), 1)
        self.assertEqual(len(lr), 1)
        self.assertEqual(len(rl), 1)
        self.assertEqual(len(rr), 1)
예제 #23
0
class TestBinaryTree(TestCase):
    """Tests for BinaryTree: root validation, traversal, height and token output."""

    def setUp(self):
        self.tree = BinaryTree()
        self.root = BinaryTreeNode('*')
        self.tree.root = self.root

    def test_root(self):
        tree = BinaryTree()
        # Assigning a non-node value (here a str) as the root must raise TypeError.
        with self.assertRaises(TypeError):
            tree.root = 'bad type'

    def test_create_graph(self):
        result = self.tree.create_graph()
        self.assertIsInstance(result, Digraph)

    def test_select_postorder(self):
        left = self.root.add_left(20)
        right = self.root.add_right(30)
        ll = left.add_left(40)
        lr = left.add_right(50)
        rl = right.add_left(60)
        rr = right.add_right(70)
        # Post-order: left subtree, right subtree, then the node itself.
        self.assertEqual(self.tree.select_postorder(0), ll)
        self.assertEqual(self.tree.select_postorder(1), lr)
        self.assertEqual(self.tree.select_postorder(2), left)
        self.assertEqual(self.tree.select_postorder(3), rl)
        self.assertEqual(self.tree.select_postorder(4), rr)
        self.assertEqual(self.tree.select_postorder(5), right)
        self.assertEqual(self.tree.select_postorder(6), self.root)

    def test_height(self):
        tree = BinaryTree()
        # An empty tree has height 0.
        self.assertEqual(tree.height(), 0)

        tree.root = BinaryTreeNode('*')
        self.assertEqual(tree.height(), 1)

        left = tree.root.add_left(10)
        self.assertEqual(tree.height(), 2)
        right = tree.root.add_right(20)
        self.assertEqual(tree.height(), 2)

        ll = left.add_left(40)
        self.assertEqual(tree.height(), 3)
        left.add_right(50)
        self.assertEqual(tree.height(), 3)
        right.add_left(60)
        self.assertEqual(tree.height(), 3)
        right.add_right(70)
        self.assertEqual(tree.height(), 3)

        ll.add_left(80)
        self.assertEqual(tree.height(), 4)

    def test_infix_tokens(self):
        # Tree:  *(+(A, B), +(C, D))  ->  ((A + B) * (C + D))
        left = self.root.add_left('+')
        right = self.root.add_right('+')
        left.add_left('A')
        left.add_right('B')
        right.add_left('C')
        right.add_right('D')

        tokens = [
            '(', '(', 'A', '+', 'B', ')', self.tree.root.label, '(', 'C', '+',
            'D', ')', ')'
        ]
        result = self.tree.infix_tokens()
        # Order matters for infix output; assertCountEqual would accept any
        # permutation of these tokens.
        self.assertEqual(list(result), tokens)

    def test_postfix_tokens(self):
        # Tree:  *(+(A, B), +(C, D))  ->  A B + C D + *
        tree = BinaryTree()
        root = BinaryTreeNode('*')
        tree.root = root

        left = root.add_left('+')
        right = root.add_right('+')
        left.add_left('A')
        left.add_right('B')
        right.add_left('C')
        right.add_right('D')

        tokens = ['A', 'B', '+', 'C', 'D', '+', tree.root.label]
        result = tree.postfix_tokens()
        # Order-sensitive comparison: each operator must follow its operands.
        self.assertEqual(list(result), tokens)

        # Tree:  +(+(A, B), +(C, D))  ->  A B + C D + +
        tree = BinaryTree()
        root = BinaryTreeNode('+')
        tree.root = root

        left = root.add_left('+')
        right = root.add_right('+')
        left.add_left('A')
        left.add_right('B')
        right.add_left('C')
        right.add_right('D')

        tokens = ['A', 'B', '+', 'C', 'D', '+', '+']
        result = tree.postfix_tokens()
        self.assertEqual(list(result), tokens)