def test_graph():
    """Smoke-test ``Graph`` on a small network that contains a cycle.

    Nodes 1-2 are inputs, 3-4 are outputs and 5-7 are untyped (hidden).
    Connections 2, 6 and 7 form the loop 5 -> 6 -> 7 -> 5. The resulting
    computation order is only printed, not asserted.
    """
    nodes = [Node(1, 'input'), Node(2, 'input'), Node(3, 'output'), Node(4, 'output')]
    nodes += [Node(k) for k in (5, 6, 7)]

    # (source node id, target node id); connection ids are 1..7 in this order.
    # Node ids are 1-based positions in ``nodes``, hence the ``- 1`` below.
    links = [(1, 5), (5, 6), (6, 3), (2, 7), (7, 4), (6, 7), (7, 5)]
    connections = [
        Connection(nodes[src - 1], nodes[dst - 1], id_=cid)
        for cid, (src, dst) in enumerate(links, start=1)
    ]

    graph = Graph(nodes, connections)
    graph.get_computation_order()
    print(graph.computation_order)
class Genom: def __init__(self, reg: Registry, connections: list[Connection] = None): # global registry to get nodes and connections from self.registry = reg # for ordering self.net_graph, self.computation_order = None, None self.ready = False self.normal_ops = 0 # nodes and connections self.nodes = [] self.connections = [] # Fitness self.fitness = None # random connections and weights if connections is None: self.initial_connections() # initialize with existing connections else: self.inherit_connections(connections) def forward(self, x: np.array): if not self.ready: self.ordered_graph() for n in self.nodes: n.reset_vals() assert x.size == self.registry.inputs, 'size of input {x.size} must match input nodes {self.registry.inputs}' # assign input values from x for i, n in enumerate([no for no in self.nodes if no.type == 'input']): n.in_val = x[i] for ob in self.computation_order: ob.eval_ordered() out = self.get_output() return out def forward_rec(self, x: np.array): if not self.ready: self.ordered_graph() assert x.size == self.registry.inputs, 'size of input {x.size} must match input nodes {self.registry.inputs}' # clear input values for all nodes for n in self.nodes: n.reset_vals() # assign input values from x for i, n in enumerate([no for no in self.nodes if no.type == 'input']): n.in_val = x[i] # calculate all recurrent connections first (output also recurrent, so do this later) for ob in self.computation_order[self.normal_ops:]: if isinstance(ob, Connection): if ob.n2.type == 'output': print('dd') continue else: raise ValueError( 'there shouldn\'t be a node in recurrent operations list... 
some error in graph.py' ) ob.eval_recurrent() # standard operations in ordered graph for ob in self.computation_order[:self.normal_ops]: ob.eval_ordered() out = self.get_output() # print(out) return out def get_output(self): out = [n for n in self.nodes if n.type == 'output'] ret = np.array([n.out_val for n in out]) # print(ret) return softmax(ret) def initial_connections(self): self.ready = False self.nodes = deepcopy(self.registry.nodes[:self.registry.inputs + self.registry.outputs + 1]) for n in self.nodes: n.default() inputs = [ n for n in self.nodes if (n.type == 'input') or (n.type == 'bias') ] outputs = [n for n in self.nodes if n.type == 'output'] for _ in range(len(outputs)): n1 = random.choice(inputs) assert n1.type in ['input', 'bias'] n2 = outputs.pop(0) assert n2.type == 'output' c = self.registry.get_connection(n1, n2) c.rand_weight() self.connections.append(c) def inherit_connections(self, connections): self.connections = deepcopy(connections) for c in self.connections: if c.n1 not in self.nodes: n1 = self.registry.get_node(c.n1.id) self.nodes.append(n1) else: n1 = next(x for x in self.nodes if x == c.n1) if c.n2 not in self.nodes: n2 = self.registry.get_node(c.n2.id) self.nodes.append(n2) else: n2 = next(x for x in self.nodes if x == c.n2) c.set_nodes(n1, n2) self.sort_nodes_connections() def set_fitness(self, f: float): self.fitness = f def get_gene(self): ret = dict() for c in self.connections: ret[c.id] = c return ret """ *** Mutations *** """ def apply_mutations(self): self.mutate_add_node() self.mutate_link() self.mutate_weight_shift() self.mutate_enable_disable_connection() # add a new connection randomly, recurrent connections are possible def mutate_link(self): self.ready = False if random.random() < 1 - P_NEW_LINK: return n1 = random.choice(self.nodes) n2 = random.choice(self.nodes) # if this connection exists --> enable it if Connection(n1, n2) in self.connections: c = next(c for c in self.connections if c == Connection(n1, n2)) c.enable() 
# else if this connection is possible --> create it elif (n1.type != 'output') and ((n2.type != 'bias') or (n2.type != 'input')): c = self.registry.get_connection(n1, n2) c.rand_weight() self.connections.append(c) self.sort_nodes_connections() # add node in connection, old connection disabled def mutate_add_node(self): self.ready = False if random.random() < 1 - P_NEW_NODE: return # select random connection, disable it c = random.choice(self.connections) c.disable() # get the splitting node n_add = self.registry.split_connection(c) # check if max hidden size is reached if not n_add: return # create two new connections instead c1_add = self.registry.get_connection(c.n1, n_add) c1_add.weight = 1 c2_add = self.registry.get_connection(n_add, c.n2) c2_add.weight = c.weight # append new connections and node self.nodes.append(n_add) self.connections.append(c1_add) self.connections.append(c2_add) self.sort_nodes_connections() # randomly enable/disable a connection def mutate_enable_disable_connection(self): self.ready = False if random.random() < 1 - P_ENDISABLE: return # for c in self.connections: # if c.enabled is False: # c.enable() # break c = random.choice(self.connections) c.en_dis_able() self.sort_nodes_connections() # shift weight def mutate_weight_shift(self): self.ready = False # weight * [0, 2] if random.random() < 1 - P_WEIGHT: return for c in self.connections: if random.random() < .9: c.weight += random.uniform(-.5, .5) else: c.weight = random.uniform(-2, 2) self.sort_nodes_connections() """ *** prepare for forward *** """ def sort_nodes_connections(self): self.connections = sorted(self.connections, key=lambda x: x.id) self.nodes = sorted(self.nodes, key=lambda x: x.id) def ordered_graph(self): # get ordered nodes and connections self.sort_nodes_connections() # self.set_next_nodes() self.net_graph = Graph(self.nodes, self.connections) self.computation_order, self.normal_ops = self.net_graph.get_computation_order( ) self.ready = True """ OLD, don't use