Example #1
0
def test_graph():
    """Build a small seven-node graph and print its computation order.

    Nodes 1-2 are inputs, 3-4 are outputs and 5-7 are created without an
    explicit type.  Connections 2, 6 and 7 link nodes 5 -> 6 -> 7 -> 5, so
    the graph contains a cycle.  Nothing is asserted; the result is only
    printed.
    """
    explicit_types = {1: 'input', 2: 'input', 3: 'output', 4: 'output'}
    nodes = [
        Node(node_id, explicit_types[node_id])
        if node_id in explicit_types else Node(node_id)
        for node_id in range(1, 8)
    ]

    # (source index, target index) into `nodes`; ids are assigned 1..7 in order
    wiring = [(0, 4), (4, 5), (5, 2), (1, 6), (6, 3), (5, 6), (6, 4)]
    connections = [
        Connection(nodes[src], nodes[dst], id_=conn_id)
        for conn_id, (src, dst) in enumerate(wiring, start=1)
    ]

    graph = Graph(nodes, connections)
    graph.get_computation_order()
    print(graph.computation_order)
Example #2
0
class Genom:
    def __init__(self, reg: Registry, connections: list[Connection] = None):
        # global registry to get nodes and connections from
        self.registry = reg

        # for ordering
        self.net_graph, self.computation_order = None, None
        self.ready = False
        self.normal_ops = 0

        # nodes and connections
        self.nodes = []
        self.connections = []

        # Fitness
        self.fitness = None

        # random connections and weights
        if connections is None:
            self.initial_connections()
        # initialize with existing connections
        else:
            self.inherit_connections(connections)

    def forward(self, x: np.array):
        """Evaluate the network once in computation order.

        Builds the ordered graph first if the genome is not marked ready,
        resets every node, loads ``x`` into the input nodes (in the order
        they appear in ``self.nodes``) and calls ``eval_ordered`` on every
        entry of ``self.computation_order``.

        Args:
            x: Input values; ``x.size`` must equal ``self.registry.inputs``.

        Returns:
            Whatever ``self.get_output()`` returns.

        Raises:
            AssertionError: If the size of ``x`` does not match the number of
                registry inputs.
        """
        if not self.ready:
            self.ordered_graph()

        # Validate before touching node state so a bad input leaves the
        # nodes as they were (same order as forward_rec).
        assert x.size == self.registry.inputs, (
            f'size of input {x.size} must match input nodes {self.registry.inputs}'
        )

        for n in self.nodes:
            n.reset_vals()

        # assign input values from x
        input_nodes = (n for n in self.nodes if n.type == 'input')
        for i, n in enumerate(input_nodes):
            n.in_val = x[i]

        for ob in self.computation_order:
            ob.eval_ordered()

        return self.get_output()

    def forward_rec(self, x: np.array):
        """Evaluate the network once, handling recurrent connections first.

        ``self.computation_order`` is split at ``self.normal_ops``: entries
        from that index onwards are treated as recurrent operations and are
        evaluated with ``eval_recurrent`` before the leading entries are
        evaluated with ``eval_ordered``.

        Args:
            x: Input values; ``x.size`` must equal ``self.registry.inputs``.

        Returns:
            Whatever ``self.get_output()`` returns.

        Raises:
            AssertionError: If the size of ``x`` does not match the number of
                registry inputs.
            ValueError: If the recurrent section contains something that is
                not a ``Connection``.
        """
        if not self.ready:
            self.ordered_graph()

        assert x.size == self.registry.inputs, (
            f'size of input {x.size} must match input nodes {self.registry.inputs}'
        )

        # clear input values for all nodes
        for n in self.nodes:
            n.reset_vals()

        # assign input values from x
        input_nodes = (n for n in self.nodes if n.type == 'input')
        for i, n in enumerate(input_nodes):
            n.in_val = x[i]

        # calculate all recurrent connections first
        for ob in self.computation_order[self.normal_ops:]:
            if not isinstance(ob, Connection):
                raise ValueError(
                    'there shouldn\'t be a node in recurrent operations list... some error in graph.py'
                )
            # Recurrent connections that feed an output node are skipped here.
            # NOTE(review): nothing in this method evaluates them afterwards;
            # confirm that is intended.
            if ob.n2.type == 'output':
                continue
            ob.eval_recurrent()

        # standard operations in ordered graph
        for ob in self.computation_order[:self.normal_ops]:
            ob.eval_ordered()

        return self.get_output()

    def get_output(self):
        """Return the softmax of the ``out_val`` of every output node.

        Output nodes are taken in the order they appear in ``self.nodes``.
        """
        raw_values = np.array(
            [node.out_val for node in self.nodes if node.type == 'output']
        )
        return softmax(raw_values)

    def initial_connections(self):
        """Create the starting wiring: one random incoming link per output.

        Copies the first ``inputs + outputs + 1`` registry nodes, calls
        ``default()`` on each copy, and then gives every output node a single
        connection from a randomly chosen input or bias node with a random
        weight.
        """
        self.ready = False

        node_count = self.registry.inputs + self.registry.outputs + 1
        self.nodes = deepcopy(self.registry.nodes[:node_count])
        for node in self.nodes:
            node.default()

        sources = [n for n in self.nodes if n.type in ('input', 'bias')]
        targets = [n for n in self.nodes if n.type == 'output']

        for target in targets:
            source = random.choice(sources)
            assert source.type in ['input', 'bias']
            assert target.type == 'output'
            link = self.registry.get_connection(source, target)
            link.rand_weight()
            self.connections.append(link)

    def inherit_connections(self, connections):
        """Adopt a deep copy of ``connections`` and collect their nodes.

        For each copied connection, both endpoints are resolved to a node
        held in ``self.nodes``: an endpoint not yet present is fetched from
        the registry by id and appended, otherwise the already-stored equal
        node is reused.  The connection is then re-pointed at the resolved
        nodes, and finally nodes and connections are sorted.
        """
        def resolve(endpoint):
            # Fetch-and-store on first sight, reuse the stored node afterwards.
            if endpoint not in self.nodes:
                fetched = self.registry.get_node(endpoint.id)
                self.nodes.append(fetched)
                return fetched
            return next(known for known in self.nodes if known == endpoint)

        self.connections = deepcopy(connections)
        for link in self.connections:
            # Source first, then target, so a self-loop reuses the same node.
            source = resolve(link.n1)
            target = resolve(link.n2)
            link.set_nodes(source, target)

        self.sort_nodes_connections()

    def set_fitness(self, f: float):
        """Store ``f`` as this genome's fitness."""
        self.fitness = f

    def get_gene(self):
        """Return a dict mapping each connection's ``id`` to the connection.

        If two connections share an id, the later one in ``self.connections``
        wins.
        """
        return {link.id: link for link in self.connections}

    """ *** Mutations *** """

    def apply_mutations(self):
        """Run every mutation once, in a fixed order.

        Order: add node, add link, shift weights, enable/disable a connection.
        """
        mutation_steps = (
            self.mutate_add_node,
            self.mutate_link,
            self.mutate_weight_shift,
            self.mutate_enable_disable_connection,
        )
        for step in mutation_steps:
            step()

    # add a new connection randomly, recurrent connections are possible
    def mutate_link(self):
        """Maybe add (or re-enable) a connection between two random nodes.

        The mutation is skipped when a random draw falls below
        ``1 - P_NEW_LINK``.  Otherwise two nodes are picked at random (they
        may be the same node, so recurrent links are possible):

        * if an equal connection is already part of the genome, it is
          enabled;
        * otherwise a new connection with a random weight is created, unless
          the source is an output node or the target is an input or bias
          node.

        The genome is always marked as not ready, even when nothing changes.
        """
        self.ready = False
        if random.random() < 1 - P_NEW_LINK:
            return

        n1 = random.choice(self.nodes)
        n2 = random.choice(self.nodes)

        # Build the probe once instead of once per comparison; membership
        # relies on Connection equality, as the lookup here always has.
        probe = Connection(n1, n2)
        existing = next((c for c in self.connections if c == probe), None)

        # if this connection exists --> enable it
        if existing is not None:
            existing.enable()
            return

        # A link may not start at an output or end at an input/bias node.
        # (The previous `!= 'bias' or != 'input'` test was always true and
        # therefore never rejected such targets.)
        if n1.type == 'output' or n2.type in ('bias', 'input'):
            return

        # else this connection is possible --> create it
        c = self.registry.get_connection(n1, n2)
        c.rand_weight()
        self.connections.append(c)
        self.sort_nodes_connections()

    # add node in connection, old connection disabled
    def mutate_add_node(self):
        """Maybe split a random connection by inserting a node into it.

        The mutation is skipped when a random draw falls below
        ``1 - P_NEW_NODE``.  Otherwise a random connection ``c`` is chosen and
        the registry is asked for the node that splits it.  If the registry
        returns a falsy value the genome is left unchanged.  On success ``c``
        is disabled and replaced by two connections: ``c.n1 -> new`` with
        weight 1 and ``new -> c.n2`` carrying ``c``'s weight.

        The genome is always marked as not ready, even when nothing changes.
        """
        self.ready = False
        if random.random() < 1 - P_NEW_NODE:
            return

        # select random connection
        c = random.choice(self.connections)

        # get the splitting node
        n_add = self.registry.split_connection(c)

        # No node available (presumably the max hidden size is reached):
        # bail out BEFORE disabling, otherwise the genome would lose the
        # link without getting a replacement path.
        if not n_add:
            return

        # the old connection is replaced by the two new ones below
        c.disable()

        # create two new connections instead
        c1_add = self.registry.get_connection(c.n1, n_add)
        c1_add.weight = 1
        c2_add = self.registry.get_connection(n_add, c.n2)
        c2_add.weight = c.weight

        # append new connections and node
        self.nodes.append(n_add)
        self.connections.append(c1_add)
        self.connections.append(c2_add)
        self.sort_nodes_connections()

    # randomly enable/disable a connection
    def mutate_enable_disable_connection(self):
        """Maybe call ``en_dis_able()`` on one randomly chosen connection.

        The mutation is skipped when a random draw falls below
        ``1 - P_ENDISABLE``.  The genome is always marked as not ready, and
        nodes and connections are re-sorted after a change.
        """
        self.ready = False

        roll = random.random()
        if roll < 1 - P_ENDISABLE:
            return

        target = random.choice(self.connections)
        target.en_dis_able()
        self.sort_nodes_connections()

    # shift weight
    def mutate_weight_shift(self):
        """Maybe perturb the weight of every connection.

        The mutation is skipped when a random draw falls below
        ``1 - P_WEIGHT``.  Otherwise each connection is handled independently:
        with probability 0.9 its weight is nudged by a uniform value in
        [-0.5, 0.5]; in the remaining cases it is replaced by a fresh uniform
        value in [-2, 2].  The genome is always marked as not ready.
        """
        self.ready = False

        roll = random.random()
        if roll < 1 - P_WEIGHT:
            return

        for link in self.connections:
            nudge = random.random() < .9
            if not nudge:
                # rare case: draw a completely new weight
                link.weight = random.uniform(-2, 2)
                continue
            # common case: small additive shift
            link.weight += random.uniform(-.5, .5)

        self.sort_nodes_connections()

    """ *** prepare for forward *** """

    def sort_nodes_connections(self):
        """Rebind ``connections`` and ``nodes`` to new lists sorted by ``id``.

        The previous list objects are left untouched; elements with equal ids
        keep their relative order.
        """
        def by_id(item):
            return item.id

        ordered_connections = list(self.connections)
        ordered_connections.sort(key=by_id)
        self.connections = ordered_connections

        ordered_nodes = list(self.nodes)
        ordered_nodes.sort(key=by_id)
        self.nodes = ordered_nodes

    def ordered_graph(self):
        """Build the graph for this genome and cache its computation order.

        Sorts nodes and connections, wraps them in a ``Graph``, stores the
        two values returned by ``get_computation_order()`` as
        ``computation_order`` and ``normal_ops``, and marks the genome ready.
        """
        self.sort_nodes_connections()

        graph = Graph(self.nodes, self.connections)
        self.net_graph = graph

        order, normal_count = graph.get_computation_order()
        self.computation_order = order
        self.normal_ops = normal_count

        self.ready = True

    """ OLD, don't use