Example #1
0
def replace_node(solution, replaced_node, new_node, destructive=True):
    """
    Replace a node in a solution by another node.

    The subtree rooted at ``replaced_node`` is swapped for ``new_node`` (with
    whatever subtree hangs from it), then the solution's ``n_nodes`` and
    ``depth`` are updated.

    :param solution: class Solution.
    :param replaced_node: class Node. A node to be replaced in the solution.
        It is matched by identity (``is``) when it is the root.
    :param new_node: class Node. A node set to replaced point in the solution.
    :param destructive: bool. If true, solution is modified in place, keeping
        its object. Otherwise, a new Solution instance is created (the nodes
        along the path to the replaced point are copied with
        ``copy_nodes_along_graph``), protecting the original solution.
    :return solution: class Solution. The modified solution (destructive) or
        the newly created one (non-destructive).
    :raises ValueError: if ``replaced_node`` is not in the solution's tree.
    """

    # TODO: Type check if ``solution'' is Solution and ``nodes'' are Node
    # Replacing the root changes the whole tree, so recompute everything.
    if solution.root is replaced_node:
        if destructive:
            solution.root = new_node
        else:
            solution = Solution(new_node)

        set_solution_n_nodes(solution)
        set_solution_depth(solution)

        return solution

    # Otherwise, locate the path from the root to replaced_node; its last
    # entry is the (child index, parent) pair of the replaced point.
    try:
        graph = get_graph_to_target(solution.root, replaced_node)
    except ValueError as err:
        msg = 'replaced_node must be in a tree of a solution.'
        raise ValueError(msg) from err

    # The node count can be updated incrementally: drop the replaced
    # subtree's nodes and add the new subtree's nodes. Read solution.n_nodes
    # before ``solution`` may be rebound below.
    n_rpl_nodes = len(get_all_node(replaced_node))
    n_new_nodes = len(get_all_node(new_node))
    n_nodes = solution.n_nodes + n_new_nodes - n_rpl_nodes

    # Obtain the replaced point
    if destructive:
        idx, parent = graph[-1]
    else:
        idx, parent, root = copy_nodes_along_graph(graph)
        solution = Solution(root)

    # Replace the replaced_node by new_node
    parent.children[idx] = new_node

    # The depth cannot be derived incrementally: the deepest leaf may lie
    # inside or outside the replaced subtree, and the new subtree may be
    # deeper or shallower than the old one. Recompute it from the updated
    # tree, as the root branch above does.
    set_solution_depth(solution)
    set_solution_n_nodes(solution, n_nodes)

    return solution
Example #2
0
def one_point(solution, func_bank):
    """
    Core function of one_point mutation.

    One node of the tree is picked at random and given a function id drawn
    from ``func_bank``'s list for that node's number of children, chosen so
    that it differs from the node's current id whenever possible.

    :param solution: solution object. solution which is applied mutation.
    :param func_bank: function bank object. function bank which is defined in problem.py.
    :return: solution object. The same object that was passed in.
    :raises ValueError: if func_bank has no function list for the chosen
        node's number of children.
    """
    target = random.choice(node.get_all_node(solution.root))
    arity = len(target.children)
    candidates = func_bank.get_function_list(arity)
    if candidates is None:
        raise ValueError(
            "function bank must have {}'s function list, but it has no list."
            .format(arity))

    # A single candidate leaves nothing to switch to.
    if len(candidates) == 1:
        return solution

    # Two distinct draws guarantee at least one differs from the current id.
    first, second = random.sample(candidates, 2)
    node.set_id(target, first if target.func_id != first else second)
    return solution
Example #3
0
    def test_get_all_node(self):
        """get_all_node lists parents before their children, left to right."""
        n1, n2, n3 = self.n1, self.n2, self.n3
        n4, n5, n6 = self.n4, self.n5, self.n6

        node.set_id(n1, 0)
        node.set_id(n2, 1)
        # Tree shape: n1 -> (n2 -> (n4, n5, n6), n3)
        node.set_children(n1, [n2, n3])
        node.set_children(n2, [n4, n5, n6])

        expected = [n1, n2, n4, n5, n6, n3]
        self.assertEqual(node.get_all_node(n1), expected)
Example #4
0
    def test_destructive_crossover_core(self):
        # Cross the root of parent 0 with the third node (get_all_node order)
        # of parent 1.
        first_parent_nodes = node.get_all_node(self.parents[0].root)
        second_parent_nodes = node.get_all_node(self.parents[1].root)
        points = [first_parent_nodes[0], second_parent_nodes[2]]

        # Expected tree of the first offspring: a single node built with 1.
        expected_nodes1 = [node.Node(1)]

        # Expected tree of the second offspring: nine nodes, wired by list
        # position as 0 -> (1, 8), 1 -> (2, 7), 2 -> (3, 6), 3 -> (4, 5).
        expected_nodes2 = [node.Node(label) for label in (0, 1, 0, 1, 1, 0, 0, 0, 0)]
        for parent_pos, child_pos in ((0, (1, 8)), (1, (2, 7)),
                                      (2, (3, 6)), (3, (4, 5))):
            node.set_children(expected_nodes2[parent_pos],
                              [expected_nodes2[pos] for pos in child_pos])

        new_s1, new_s2 = co.destructive_crossover(self.parents, points)
        self.assertTrue(node.node_equal(expected_nodes1[0], new_s1.root, as_tree=True))
        self.assertTrue(node.node_equal(expected_nodes2[0], new_s2.root, as_tree=True))
        # Destructive crossover must hand back the parent objects themselves.
        self.assertTrue(self.parents[0] is new_s1)
        self.assertTrue(self.parents[1] is new_s2)
Example #5
0
    def test_initialize(self):
        # t_prob=0: the generated tree is expected to reach exactly max_depth
        # and to pass node.nodes_checker.
        full_tree = RandomInitializer(0, self.max_depth, self.problem)()
        solution.set_solution_depth(full_tree)
        self.assertEqual(full_tree.depth, self.max_depth)
        self.assertEqual(node.nodes_checker(node.get_all_node(full_tree.root)), None)

        # t_prob=1: the generated tree is expected to have depth 0.
        single_node_tree = RandomInitializer(1, self.max_depth, self.problem)()
        solution.set_solution_depth(single_node_tree)
        self.assertEqual(single_node_tree.depth, 0)
Example #6
0
def select_random_points(solution, k):
    """
    Obtain `k` points in the solution at random, without replacement.

    :param solution: class `Solution`
    :param k: the number of points to obtain. Must satisfy
        ``0 <= k <= number of nodes in the solution``.
    :return: a list of class `Node`
    :raises ValueError: if `k` is negative or larger than the number of nodes.
    """
    # TODO type check if ``solution'' is Solution
    node_list = get_all_node(solution.root)
    n_nodes = len(node_list)
    # random.sample would raise ValueError here too, but with a generic
    # message; report the actual sizes so callers can see what went wrong.
    if not 0 <= k <= n_nodes:
        msg = ('k must be between 0 and the number of nodes ({}), '
               'but got {}.'.format(n_nodes, k))
        raise ValueError(msg)

    return random.sample(node_list, k=k)
Example #7
0
 def get_target_node(self, root):
     """
     Get a list of target nodes, selected by ``self.target_node``.

     :param root: node object. a node of target solution
     :return: list of node object. target node list:
         'nonterminal' -> nodes returned by node.get_all_nonterminal_nodes,
         'terminal'    -> nodes without children,
         'all'         -> every node returned by node.get_all_node.
     :raises ValueError: if ``self.target_node`` is none of the above.
     """
     if self.target_node == 'nonterminal':
         return node.get_all_nonterminal_nodes(root)
     elif self.target_node == 'terminal':
         # Terminals are the leaves: nodes with an empty children list.
         return [n for n in node.get_all_node(root) if not n.children]
     elif self.target_node == 'all':
         return node.get_all_node(root)
     else:
         msg = '{} is not found'.format(self.target_node)
         raise ValueError(msg)
Example #8
0
 def test_onepoint_mutation(self):
     """Every node of the mutated solution must carry a function id that the
     problem's function bank allows for that node's number of children."""
     s2 = self.mutation(self.s1)
     for n in node.get_all_node(s2.root):
         func_list = self.problem.func_bank.get_function_list(n_children=len(n.children))
         # assertIn reports both the offending id and the allowed list on failure.
         self.assertIn(n.func_id, func_list)
Example #9
0
def _calc_solution_n_nodes(solution):
    """Return how many nodes the solution's tree holds, root included."""
    # TODO type check if ``solution'' is Solution
    return len(get_all_node(solution.root))