Exemplo n.º 1
0
    def test_calc_node_depth(self):
        """calc_node_depth reports 2 for a tree n1 -> n2 -> {n4, n5, n6}."""
        root, inner = self.n1, self.n2
        node.set_id(root, 0)
        node.set_id(inner, 1)
        node.set_children(root, [inner, self.n3])
        node.set_children(inner, [self.n4, self.n5, self.n6])

        depth = node.calc_node_depth(root)
        self.assertEqual(depth, 2)
Exemplo n.º 2
0
def one_point(solution, func_bank):
    """
    Apply one-point mutation to a solution.

    One node is drawn at random from the tree under ``solution.root`` and its
    function id is replaced by a different id taken from the function list
    that ``func_bank`` returns for the node's number of children.
    :param solution: solution object. solution to mutate; its tree is modified in place.
    :param func_bank: function bank object. queried via ``get_function_list(n_children)``.
    :return: solution object. the same object that was passed in.
    :raises ValueError: if the function bank returns None for the node's number of children.
    """
    target = random.choice(node.get_all_node(solution.root))
    arity = len(target.children)
    candidates = func_bank.get_function_list(arity)
    if candidates is None:
        raise ValueError(
            "function bank must have {}'s function list, but it has no list.".
            format(arity))

    # A single function for this arity leaves nothing to swap to.
    if len(candidates) == 1:
        return solution

    # Draw two distinct entries so that at least one differs from the current id.
    first, second = random.sample(candidates, 2)
    replacement = second if target.func_id == first else first
    node.set_id(target, replacement)

    return solution
Exemplo n.º 3
0
    def test_get_all_node(self):
        """get_all_node lists every node of the tree, each parent before its children."""
        root, inner = self.n1, self.n2
        node.set_id(root, 0)
        node.set_id(inner, 1)
        node.set_children(root, [inner, self.n3])
        node.set_children(inner, [self.n4, self.n5, self.n6])

        # Expected order: root, its first child with that child's subtree, then n3.
        expected = [root, inner, self.n4, self.n5, self.n6, self.n3]
        self.assertEqual(node.get_all_node(root), expected)
Exemplo n.º 4
0
    def test_get_nonterminal_nodes(self):
        """Every node returned by get_all_nonterminal_nodes is one that was given children."""
        node.set_id(self.n1, 0)
        node.set_id(self.n2, 1)
        node.set_children(self.n1, [self.n2, self.n3])
        node.set_children(self.n2, [self.n4, self.n5])
        node.set_children(self.n3, [self.n6])

        found_nodes = node.get_all_nonterminal_nodes(self.n1)
        # n1, n2 and n3 are the only nodes that received children above.
        allowed = [self.n1, self.n2, self.n3]

        for candidate in found_nodes:
            self.assertTrue(candidate in allowed)
Exemplo n.º 5
0
    def test_get_parent_node(self):
        """get_parent_node returns (child index, parent) and raises ValueError when no parent exists."""
        node.set_id(self.n1, 0)
        node.set_id(self.n2, 1)
        node.set_children(self.n1, [self.n2, self.n3])
        node.set_children(self.n2, [self.n4, self.n5, self.n6])

        # n4 is the first child (index 0) of n2.
        located = node.get_parent_node(self.n1, self.n4)
        self.assertEqual(located, (0, self.n2))

        # NOTE: ``msg`` is only the text shown if the assertion fails; it is
        # not matched against the message of the raised exception.
        with self.assertRaises(ValueError, msg='There is no parent of root.'):
            node.get_parent_node(self.n1, self.n1)

        with self.assertRaises(ValueError, msg='Invalid arguments: cannot find parent.'):
            node.get_parent_node(node.Node(), node.Node())
Exemplo n.º 6
0
    def test__copy_nodes_along_graph(self):
        """copy_nodes_along_graph yields equal-but-distinct copies of the nodes on the given path."""
        node.set_id(self.n1, 0)
        node.set_id(self.n2, 1)
        node.set_children(self.n1, [self.n2, self.n3])
        node.set_children(self.n2, [self.n4, self.n5])
        node.set_children(self.n3, [self.n6])

        path = [(0, self.n1), (1, self.n2)]
        pos, parent, root = node.copy_nodes_along_graph(path)

        self.assertEqual(pos, 1)
        # The copies compare equal to the originals...
        self.assertTrue(node.node_equal(parent, self.n2))
        self.assertTrue(node.node_equal(root, self.n1, as_tree=True))
        # ...but must be different objects.
        self.assertIsNot(parent, self.n2)
        self.assertIsNot(root, self.n1)
Exemplo n.º 7
0
        def new_node(parent, depth):
            """
            Recursively build a random node (and its subtree) and attach it to ``parent``.

            :param parent: node object or None. when not None, the created node is
                appended to ``parent.children``.
            :param depth: int. depth of the node being created; compared with ``self.max_depth``.
            :return: the created node when ``parent`` is None; otherwise nothing is
                returned explicitly (None).
            """
            current_node = node.Node()
            # Choose a terminal when the random draw falls below t_prob, or
            # unconditionally once the depth limit has been reached.
            if self.t_prob > random.random() or depth == self.max_depth:
                func_id = random.choice(self.terminal_list)
            else:
                # Reset children to an empty list so the recursive calls below
                # can append to it.
                current_node.children = []
                func_id = random.choice(self.nonterminal_list)
                n_child = node.get_n_children(
                    func_id, self.func_bank.get_function_list())
                # Each recursive call appends one new child to current_node.
                for _ in range(n_child):
                    new_node(current_node, depth + 1)

            node.set_id(current_node, func_id)

            if parent is not None:
                parent.children.append(current_node)
            else:
                # Only the top-level call (no parent) hands the node back.
                return current_node
Exemplo n.º 8
0
    def test_solution_equal(self):
        """solution_equal is checked against self.s1 with both values of its third argument."""
        # Six fresh nodes constructed with ids 0, 1, 0, 1, 0, 1 (same order as na1..na6).
        nodes = [node.Node(i % 2) for i in range(6)]

        node.set_id(nodes[0], 0)
        node.set_id(nodes[1], 1)
        node.set_children(nodes[0], [nodes[1], nodes[2]])
        node.set_children(nodes[1], nodes[3:6])
        candidate = solution.Solution(nodes[0])

        # NOTE(review): the third positional argument presumably selects a
        # tree-wise comparison -- confirm against solution.solution_equal.
        self.assertTrue(solution.solution_equal(self.s1, candidate, True))
        self.assertFalse(solution.solution_equal(self.s1, candidate, False))

        # Swap the first grandchild for a node built with a different id.
        nodes[3] = node.Node(0)
        node.set_children(nodes[1], nodes[3:6])
        self.assertFalse(solution.solution_equal(self.s1, candidate, True))
Exemplo n.º 9
0
    def test_destructive_replace_node(self):
        """replace_node(destructive=True) rewrites self.s1 in place and returns that same object."""
        graft_root = node.Node(1)
        graft_leaf = node.Node(0)

        node.set_id(graft_root, 0)
        node.set_id(graft_leaf, 1)
        node.set_children(graft_root, [graft_leaf])
        snapshot = copy.deepcopy(self.s1)
        result = solution.replace_node(self.s1, self.n2, graft_root, destructive=True)

        # Expected tree: root(0) -> [inner(0) -> [leaf(1)], sibling(0)]
        exp_root, exp_inner, exp_leaf, exp_sibling = (
            node.Node(0), node.Node(0), node.Node(1), node.Node(0))
        expected_nodes = (exp_root, exp_inner, exp_leaf, exp_sibling)
        node.set_children(exp_root, [exp_inner, exp_sibling])
        node.set_children(exp_inner, [exp_leaf])

        self.assertTrue(node.node_equal(result.root, exp_root, as_tree=True))
        self.assertEqual(result.n_nodes, len(expected_nodes))
        self.assertEqual(result.depth, 2)
        # Check whether the original solution is NOT protected.
        self.assertFalse(solution.solution_equal(snapshot, result))
        self.assertIs(self.s1, result)
Exemplo n.º 10
0
    def test_stop_improvement_with_shuffle(self):
        """
        With shuffling on, stop_improvement on the root must produce both
        function ids 2 and 4 within 100 attempts, and must report no
        improvement for either child of the root.
        """
        root = self.s1.root
        first_child, second_child = root.children[0], root.children[1]
        is_shuffle = True

        wanted_ids = {2, 4}
        seen_ids = set()
        for attempt in range(100):
            improved = localsearch.stop_improvement(
                self.s1, root, self.problem, is_shuffle=is_shuffle)
            if attempt == 0:
                self.assertEqual(improved, True)
            if root.func_id in (2, 4):
                seen_ids.add(root.func_id)
                # Put the root and the recorded fitness back so the next
                # attempt starts from the same state.
                node.set_id(root, self.root_id)
                solution.set_previous_fitness(self.s1, self.root_id)

            if seen_ids == wanted_ids:
                break

        self.assertEqual(seen_ids == wanted_ids, True)

        for child in (first_child, second_child):
            improved = localsearch.stop_improvement(
                self.s1, child, self.problem, is_shuffle=is_shuffle)
            self.assertEqual(improved, False)
Exemplo n.º 11
0
def improve(target_solution, target_node, candidate_id, problem):
    """
    Core function for local search.
    Replace the old function with a new function and then revert it if fitness is not improved.

    The baseline fitness is taken BEFORE the node is modified. Computing it
    after the replacement would compare the candidate against itself, so a
    solution without a cached fitness could never be improved.
    :param target_solution: solution object. target solution of local search.
    :param target_node: node object. target node of the target solution.
    :param candidate_id: int. ID of candidate function for local search.
    :param problem: problem object. problem for calculation of fitness.
    :return: bool. if improvement is success, return True.
    """
    # Baseline: use the cached fitness, or evaluate the still-unmodified solution.
    pre_fit = target_solution.previous_fitness
    if pre_fit is None:
        pre_fit = problem.fitness(target_solution)

    pre_id = target_node.func_id
    node.set_id(target_node, candidate_id)
    new_fit = problem.fitness(target_solution)  # fitness with the candidate in place.

    if pre_fit >= new_fit:
        # Not strictly better: restore the function id and the baseline fitness.
        node.set_id(target_node, pre_id)
        solution.set_previous_fitness(target_solution, pre_fit)
        return False

    # NOTE(review): nothing is stored here on success; presumably
    # problem.fitness records the new fitness on the solution -- confirm.
    return True
Exemplo n.º 12
0
def bihc(target_solution, node_list, fs_core):
    """
    Best improvement hill climber (BIHC) function.

    Each round tries ``fs_core`` on every candidate node, undoes the trial,
    and finally adopts only the single change that gave the highest fitness.
    The loop stops when no candidate improves the solution or when the
    candidate list is empty.
    :param target_solution: solution object. solution to improve; modified in place.
        Its ``previous_fitness`` is compared with ``<`` and is therefore expected to be set.
    :param node_list: list of node object. candidate node list. NOTE: this list is
        modified in place (nodes are removed and re-appended).
    :param fs_core: function. search function for a target node, called as
        ``fs_core(target_solution, target_node)``; its outcome is read back from
        ``target_node.func_id`` and ``target_solution.previous_fitness``.
    :return: solution object.
    """
    pre_node = None
    while node_list:
        ori_fit = target_solution.previous_fitness
        best_fit = ori_fit
        best_node = None
        best_id = None
        for target_node in node_list:  # Try to find the best-improving target node and the function id.
            ori_id = target_node.func_id
            fs_core(target_solution, target_node)
            trial_fit = target_solution.previous_fitness
            if best_fit < trial_fit:
                # Remember the fitness too, so later candidates must beat this
                # one and the adopted solution is stamped with the right value.
                best_fit = trial_fit
                best_node = target_node
                best_id = target_node.func_id
            # the target solution is reverted to the original for the next iteration.
            node.set_id(target_node, ori_id)
            solution.set_previous_fitness(target_solution, ori_fit)
        # If there is no improvement, end this local search.
        if best_node is None:
            break

        # Otherwise, adopt the improvement to the solution.
        node.set_id(best_node, best_id)
        solution.set_previous_fitness(target_solution, best_fit)

        # The node changed in the previous round becomes a candidate again,
        # while the node just changed is excluded for the next round.
        if pre_node is not None:
            node_list.append(pre_node)
        node_list.remove(best_node)  # remove the replaced node from candidate node list
        pre_node = best_node

    return target_solution
Exemplo n.º 13
0
 def make_node(func_id):
     """Create a new ``node.Node`` and assign ``func_id`` to it through ``node.set_id``."""
     created = node.Node()
     node.set_id(created, func_id)
     return created
Exemplo n.º 14
0
    def test_set_id(self):
        """set_id stores the given id in the node's ``func_id`` attribute."""
        expected_id = 0
        node.set_id(self.n1, expected_id)

        self.assertEqual(self.n1.func_id, expected_id)
Exemplo n.º 15
0
 def test_node_array_equal(self):
     """node_array_equal holds for the fixture nodes until one node's id is changed."""
     left = (self.n1, self.n2)
     right = (self.n3, self.n4)

     # Fresh lists are built for every call, as in the original test.
     self.assertTrue(node.node_array_equal(list(left), list(right)))
     node.set_id(self.n4, 0)
     self.assertFalse(node.node_array_equal(list(left), list(right)))
Exemplo n.º 16
0
    def test_node_equal(self):
        """node_equal is checked on single nodes and, with ``as_tree=True``, on whole subtrees."""
        # Start with every fixture node carrying the same id.
        for fixture_node in (self.n1, self.n2, self.n3,
                             self.n4, self.n5, self.n6):
            node.set_id(fixture_node, 0)
        node.set_children(self.n1, [self.n2, self.n3])
        node.set_children(self.n4, [self.n5, self.n6])

        left_root, right_root = self.n1, self.n4
        self.assertTrue(node.node_equal(left_root, right_root))
        self.assertTrue(node.node_equal(left_root, right_root, as_tree=True))

        # Changing one child's id must break the affected comparisons.
        node.set_id(self.n2, 1)
        self.assertFalse(node.node_equal(left_root, self.n3))
        self.assertFalse(node.node_equal(self.n2, self.n3))
        self.assertFalse(node.node_equal(left_root, right_root, as_tree=True))