-
Notifications
You must be signed in to change notification settings - Fork 0
/
neatHandler.py
125 lines (107 loc) · 4.58 KB
/
neatHandler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import os
import neat
from circuit import Circuit
import fileinput
import sys
import re
import visualize
class neatHandler:
    """Evolve NEAT feed-forward networks that route wires on a ``Circuit``.

    Constructing an instance does all the work:
    1. start a circuit;
    2. rewrite ``num_inputs``/``num_outputs`` in the ``config`` file that sits
       next to this module so they match the circuit;
    3. run the NEAT evolution (see ``run``).
    """

    def __init__(self):
        local_dir = os.path.dirname(__file__)
        config_path = os.path.join(local_dir, 'config')
        self.circuit = Circuit()
        self.circuit.start()
        self.wireNames = self.circuit.getWires()
        # Must match getInputsForCircuit: the path-matrix entries (presumably
        # what getPathMatrixSize() counts -- confirm) plus 4 values per wire.
        self.numInputs = self.circuit.getPathMatrixSize() + len(self.wireNames) * 4
        # Must match getMoveList: one activation per (wire, direction) pair.
        self.numOutputs = len(self.wireNames) * 4
        self.editConfig(config_path)
        # -inf so the first evaluated genome always registers as a new best.
        self.bestFitness = float("-inf")
        self.stats = None
        self.run(config_path)

    def editConfig(self, config_path):
        """Rewrite ``num_inputs``/``num_outputs`` in the NEAT config file in place.

        Lines of the form ``num_inputs = <digits>`` / ``num_outputs = <digits>``
        are replaced with ``self.numInputs`` / ``self.numOutputs``. Whitespace
        around ``=`` and at the end of the line is optional. Every other line is
        copied through unchanged.

        :param config_path: path of the neat-python config file to edit.
        """
        # Raw strings: "\s" and "\d" in a normal literal are invalid escape
        # sequences (SyntaxWarning on Python 3.12+).
        input_pattern = re.compile(r"^num_inputs\s*=\s*\d+\s*$")
        output_pattern = re.compile(r"^num_outputs\s*=\s*\d+\s*$")
        input_line = "num_inputs = " + str(self.numInputs) + "\n"
        output_line = "num_outputs = " + str(self.numOutputs) + "\n"
        # With inplace=True, fileinput redirects sys.stdout into the file being
        # edited. The ``with`` block guarantees stdout is restored and the file
        # is finalised even if an error occurs part-way through.
        with fileinput.input(config_path, inplace=True) as config_lines:
            for line in config_lines:
                line = input_pattern.sub(input_line, line)
                line = output_pattern.sub(output_line, line)
                sys.stdout.write(line)
        print("Config file edited")

    def eval_genomes(self, genomes, config):
        """NEAT fitness callback: score every genome of the current generation.

        Sets ``genome.fitness`` on each genome. Whenever a genome beats the best
        fitness seen so far, this method prints it, draws the circuit result and
        plots the statistics gathered so far.

        :param genomes: iterable of ``(genome_id, genome)`` pairs.
        :param config: the ``neat.Config`` used to build each network.
        """
        for _genome_id, genome in genomes:
            net = neat.nn.FeedForwardNetwork.create(genome, config)
            genome.fitness = self.evalNet(net)
            if self.bestFitness < genome.fitness:
                self.bestFitness = genome.fitness
                # evalNet has just run this genome, so the circuit holds its result.
                print("New Best Fitness: " + str(self.circuit.getFitness()))
                self.circuit.drawResult()
                # NOTE(review): the original file lost its indentation. This plot
                # is assumed to belong to the new-best branch; confirm it was not
                # meant to run once per generation (after the loop).
                visualize.plot_stats(self.stats, ylog=False, view=True)

    def evalNet(self, net):
        """Let ``net`` drive a freshly restarted circuit and return its fitness.

        The network is queried once per step. Stepping stops when
        ``circuit.isDone()`` is true, or when no candidate move succeeds
        (``makeMove`` returns False).
        """
        self.circuit.restart()
        while not self.circuit.isDone():
            output = net.activate(self.getInputsForCircuit())
            if not self.makeMove(output):
                break
        return self.circuit.getFitness()

    def makeMove(self, output):
        """Apply the first candidate move, in ``getMoveList`` order, that succeeds.

        :param output: the network's output vector, 4 values per wire.
        :returns: True as soon as a move method returns a truthy value,
            False if none does.
        """
        for move, _activation, wire in self.getMoveList(output):
            if move(wire):
                return True
        return False

    def getMoveList(self, output):
        """Pair every (wire, direction) with its network activation, sorted.

        ``output`` is read in groups of four per wire, in ``self.wireNames``
        order, as North, East, South, West.

        :returns: list of ``(move_method, activation, wire)`` tuples sorted by
            ascending activation. ``list.sort`` is stable, so ties keep
            wire/direction order.
        """
        # NOTE(review): makeMove tries moves in this order, so the *lowest*
        # activation is attempted first. Highest-first is the more usual
        # convention; confirm the intent before changing it.
        directions = (self.circuit.moveNorth, self.circuit.moveEast,
                      self.circuit.moveSouth, self.circuit.moveWest)
        moves = [
            (move, output[4 * wire_index + dir_index], wire)
            for wire_index, wire in enumerate(self.wireNames)
            for dir_index, move in enumerate(directions)
        ]
        moves.sort(key=lambda tup: tup[1])
        return moves

    def getInputsForCircuit(self):
        """Build the network input vector for the circuit's current state.

        :returns: a flat list containing:
            - every entry of ``getPathMatrix()``, row by row;
            - then, for each wire, the ``getWirePosition`` row/col followed by
              the ``getWireGoal`` row/col. Rows are divided by ``getRows()``
              and columns by ``getCols()``.
        """
        inputs = [cell for row in self.circuit.getPathMatrix() for cell in row]
        rows = self.circuit.getRows()
        cols = self.circuit.getCols()
        for wire in self.wireNames:
            srow, scol = self.circuit.getWirePosition(wire)
            grow, gcol = self.circuit.getWireGoal(wire)
            inputs.extend((srow / rows, scol / cols, grow / rows, gcol / cols))
        return inputs

    def run(self, config_file, generations=1000):
        """Run the NEAT evolution, then report on and draw the winning genome.

        :param config_file: path of the neat-python config file.
        :param generations: maximum number of generations to run (default
            1000, the previously hard-coded value).
        """
        config = neat.Config(neat.DefaultGenome, neat.DefaultReproduction,
                             neat.DefaultSpeciesSet, neat.DefaultStagnation,
                             config_file)
        # The population is the top-level object for a NEAT run.
        population = neat.Population(config)
        # Show progress on stdout and keep statistics for plotting.
        population.add_reporter(neat.StdOutReporter(True))
        self.stats = neat.StatisticsReporter()
        population.add_reporter(self.stats)
        winner = population.run(self.eval_genomes, generations)
        # Replay the winner so the circuit state reported below is its result.
        self.evalNet(neat.nn.FeedForwardNetwork.create(winner, config))
        visualize.plot_stats(self.stats, ylog=False, view=True)
        print("Turns: " + str(self.circuit.getTotalTurns()))
        print("WireLength: " + str(self.circuit.getWireLength()))
        print("Completed Wires: " + str(self.circuit.getCompletedWires()))
        print("Fitness: " + str(self.circuit.getFitness()))
        self.circuit.drawResult()
if __name__ == "__main__":
    # Constructing the handler runs the whole experiment (config edit + NEAT
    # run); guard it so that merely importing this module does not start it.
    n = neatHandler()