Example #1
0
class Tree(object):
    """Genetic-programming program tree.

    Holds a root node plus cached, flattened views of the tree:
    ``program`` (the node list; ``update_program`` fills it from a
    post-order traversal) and the function / terminal / input nodes found
    in it.  A ``TreeParser`` instance is used to rebuild these views and to
    render the tree as text.
    """

    def __init__(self):
        self.tree_id = None
        self.score = None
        self.tree_type = None

        self.root = None
        self.depth = 0
        self.size = 0
        # number of terminal plus input nodes; set by update_tree_info()
        self.branches = 0

        self.program = []
        self.func_nodes = []
        self.term_nodes = []
        self.input_nodes = []

        self.parser = TreeParser()

    def valid(self, config_input_nodes):
        """Return True if every configured input name appears in the tree.

        Args:
            config_input_nodes: iterable of dicts, each with a ``"name"`` key.

        Returns:
            True when the set of configured names is a subset of the names
            of ``self.input_nodes`` (as last computed), otherwise False.
        """
        required = {node["name"] for node in config_input_nodes}
        present = {node.name for node in self.input_nodes}
        return required <= present

    def get_linked_node(self, target_node):
        """Return the node that holds ``target_node`` in its ``branches``.

        Only nodes after ``target_node`` in ``self.program`` are searched.
        When the program is in post-order (as built by ``update_program``),
        a parent always follows its children.

        Returns:
            The first later node whose ``has_value_node(target_node)`` is not
            False, or None if ``target_node`` is not in the program or no
            such node exists (e.g. ``target_node`` is the root).
        """
        try:
            start = self.program.index(target_node) + 1
        except ValueError:
            return None

        # has_value_node yields a branch index (possibly 0) or False, so
        # only an explicit False means "not a parent"
        return next(
            (node for node in self.program[start:]
             if node.has_value_node(target_node) is not False),
            None,
        )

    def replace_node(self, target_node, replace_with, override_update=False):
        """Replace ``target_node`` with ``replace_with`` in the tree.

        Args:
            target_node: node currently in the tree to be replaced.
            replace_with: node (or subtree) to put in its place.
            override_update: when True, skip the call to ``update()``; the
                cached program is then left stale until refreshed.

        Raises:
            ValueError: if ``target_node`` is not the root and no parent
                holding it is found in ``self.program``.
        """
        if target_node is self.root:
            # the root has no parent to patch; swap the root itself
            self.root = replace_with
        else:
            linked_node = self.get_linked_node(target_node)
            if linked_node is None:
                raise ValueError("target node not found in tree program")
            branch_index = linked_node.has_value_node(target_node)
            linked_node.branches[branch_index] = replace_with

        if override_update is False:
            self.update()

    def equals(self, tree):
        """Return True if ``tree`` has the same program, node by node.

        The two ``program`` lists must have equal length. Nodes are then
        compared position by position with ``node.equals``; only an
        explicit False result counts as a mismatch.
        """
        if len(self.program) != len(tree.program):
            return False
        return all(
            mine.equals(theirs) is not False
            for mine, theirs in zip(self.program, tree.program)
        )

    def update_program(self):
        """Rebuild ``self.program`` by post-order traversal from the root."""
        # the old list is emptied in place (visible to anyone still holding
        # it) before the freshly traversed list is bound
        del self.program[:]
        self.program = self.parser.post_order_traverse(self.root)

    def update_func_nodes(self):
        """Refresh ``self.func_nodes`` with the non-root function nodes."""
        # slice-assign so the existing list object is updated in place
        self.func_nodes[:] = [
            node for node in self.program
            if node.is_function() and node is not self.root
        ]

    def update_term_nodes(self):
        """Refresh ``self.term_nodes`` with the terminal nodes of the program."""
        self.term_nodes[:] = [
            node for node in self.program if node.is_terminal()
        ]

    def update_input_nodes(self):
        """Refresh ``self.input_nodes`` with the input nodes of the program."""
        self.input_nodes[:] = [
            node for node in self.program if node.is_input()
        ]

    def update_tree_info(self):
        """Recompute ``size`` and ``branches`` from the cached node lists.

        Assumes ``program``, ``term_nodes`` and ``input_nodes`` are already
        up to date.
        """
        self.size = len(self.program)
        self.branches = len(self.term_nodes) + len(self.input_nodes)

    def update(self):
        """Re-parse the tree from ``self.root`` and store the new program.

        ``TreeParser.parse_tree`` is given the tree itself too.
        NOTE(review): it presumably also refreshes ``size``, ``depth`` and
        the node lists; confirm in TreeParser.
        """
        self.program = self.parser.parse_tree(self, self.root)

    def __str__(self):
        """Render the tree as text via the parser, chosen by ``tree_type``."""
        if self.tree_type == "CLASSIFICATION_TREE":
            return self.parser.parse_classification_tree(self.root)
        else:
            return self.parser.parse_equation(self.root)

    def to_dict(self):
        """Return a summary dict of the tree.

        ``"id"`` is the Python object id of this instance, not ``tree_id``.
        Node lists are stringified, and ``"program"`` is ``str(self)``.
        """
        self_dict = {
            "id": id(self),
            "score": self.score,

            "size": self.size,
            "depth": self.depth,

            "func_nodes_len": len(self.func_nodes),
            "term_nodes_len": len(self.term_nodes),
            "input_nodes_len": len(self.input_nodes),

            "func_nodes": [str(node) for node in self.func_nodes],
            "term_nodes": [str(node) for node in self.term_nodes],
            "input_nodes": [str(node) for node in self.input_nodes],

            "program": str(self)
        }
        return self_dict
Example #2
0
class TreeParserTests(unittest.TestCase):
    """Unit tests for ``TreeParser`` on a small, hand-built tree.

    ``setUp`` builds the tree ``ADD(COS(1.0), SIN(2.0))`` explicitly, so the
    expected traversal, equation string and dict form are known exactly.
    """

    def setUp(self):
        # fixed seed so any randomness in the fixtures is reproducible
        random.seed(10)

        self.config = {
            "max_population": 10,

            "tree_generation": {
                "method": "FULL_METHOD",
                "initial_max_depth": 4
            },

            "function_nodes": [
                {"type": "FUNCTION", "name": "ADD", "arity": 2},
                {"type": "FUNCTION", "name": "SUB", "arity": 2},
                {"type": "FUNCTION", "name": "MUL", "arity": 2},
                {"type": "FUNCTION", "name": "DIV", "arity": 2},
                {"type": "FUNCTION", "name": "COS", "arity": 1},
                {"type": "FUNCTION", "name": "SIN", "arity": 1}
            ],

            "terminal_nodes": [
                {"type": "CONSTANT", "value": 1.0},
                {"type": "INPUT", "name": "x"},
                {"type": "INPUT", "name": "y"},
                {"type": "INPUT", "name": "z"}
            ],

            "input_variables": [
                {"name": "x"},
                {"name": "y"},
                {"name": "z"}
            ]
        }

        self.functions = GPFunctionRegistry("SYMBOLIC_REGRESSION")
        self.generator = TreeGenerator(self.config)
        self.parser = TreeParser()

        # leaves: two constants
        left_node = Node(NodeType.CONSTANT, value=1.0)
        right_node = Node(NodeType.CONSTANT, value=2.0)

        # unary functions, each wrapping one constant
        cos_func = Node(
            NodeType.FUNCTION,
            name="COS",
            arity=1,
            branches=[left_node]
        )
        sin_func = Node(
            NodeType.FUNCTION,
            name="SIN",
            arity=1,
            branches=[right_node]
        )

        # root: ADD(COS(1.0), SIN(2.0))
        add_func = Node(
            NodeType.FUNCTION,
            name="ADD",
            arity=2,
            branches=[cos_func, sin_func]
        )

        # build the tree and populate its cached node lists
        self.tree = Tree()
        self.tree.root = add_func
        self.tree.update_program()
        self.tree.update_func_nodes()
        self.tree.update_term_nodes()

    def tearDown(self):
        # release every fixture created in setUp
        del self.config
        del self.functions
        del self.generator
        del self.parser
        del self.tree

    def test_parse_tree(self):
        program = self.parser.parse_tree(self.tree, self.tree.root)
        # debug output: each node's name, or its value for unnamed nodes
        for node in program:
            if node.name is not None:
                print(node.name)
            else:
                print(node.value)

        self.assertEqual(self.tree.size, 5)
        self.assertEqual(self.tree.depth, 2)

        self.assertEqual(len(self.tree.func_nodes), 2)
        self.assertEqual(len(self.tree.term_nodes), 2)
        self.assertEqual(len(self.tree.input_nodes), 0)

    def test_parse_equation(self):
        equation = self.parser.parse_equation(self.tree.root)
        self.assertEqual(equation, "((COS(1.0)) ADD (SIN(2.0)))")

    def test_tree_to_dict(self):
        solution = {
            'program': [
                {'type': 'CONSTANT', 'value': 1.0},
                {'arity': 1, 'type': 'FUNCTION', 'name': 'COS'},
                {'type': 'CONSTANT', 'value': 2.0},
                {'arity': 1, 'type': 'FUNCTION', 'name': 'SIN'},
                {'arity': 2, 'type': 'FUNCTION', 'root': True, 'name': 'ADD'}
            ]
        }
        results = self.parser.tree_to_dict(self.tree, self.tree.root)
        self.assertEqual(results["program"], solution["program"])