예제 #1
0
def main():
    """
    Train a population of neuronal networks without a GUI, generation after generation.

    The population is loaded from ``<networks dir>/<pop_name>.pop``; if loading
    fails, a new population of 100 networks with a random seed is created.
    Each generation is evaluated in a pool of worker processes, saved to disk,
    its best fitness is printed, and the next generation is created.

    The loop never terminates on its own; it is left only through an exception
    (e.g. KeyboardInterrupt), in which case the worker pool is shut down.
    """
    try:
        pop = Population.load_from_file(constants.res_loc("networks") + pop_name + ".pop")
    except Exception:
        # Only ordinary errors fall back to a fresh population; KeyboardInterrupt
        # and SystemExit are deliberately allowed to propagate.
        seed = random.randint(0, 1000)
        pop = Population(seed, 100)

    pool = Pool(number_of_processes)
    try:
        while True:
            # one world per network; the platform is generated right after the
            # world is constructed, before the next world is created
            worlds = []
            for net in pop.current_generation:
                nWorld = NeuronalWorld(pop.seed, net)
                worlds.append(nWorld)
                nWorld.generatePlatform()

            # evaluate all neuronal worlds
            fitnesses = pool.map(evaluate, worlds)
            # the workers only hand back the result of evaluate, so the fitness
            # is written onto the networks here in the parent process
            for world, fit in zip(worlds, fitnesses):
                world.nn.fitness = fit

            path = constants.res_loc("networks") + pop.name + ".pop"
            pop.save_to_file(path)
            best_fitness = max(nn.fitness for nn in pop.current_generation)
            print("best fitness:", best_fitness)
            pop.create_next_generation()
            pop.generation_count += 1
    finally:
        # NOTE(review): assumes Pool is multiprocessing.Pool (or API-compatible)
        # - confirm against the file's imports.
        # Do not leave worker processes behind when the loop is aborted.
        pool.terminate()
        pool.join()
예제 #2
0
class NNTraningContext(BaseContext):
    """
    Context for training neuronal networks

    With ``train=True`` one NeuronalWorld is created for every network of the
    population's current generation. As soon as no world reports that it is
    still running, the population is saved, the next generation is created and
    new worlds are built for it.

    With ``train=False`` only a single world for the network with the highest
    fitness is created and simulated.
    """

    # names of the draw methods, indexed by self.drawmode
    _DRAW_MODES = ("drawSimple", "drawNetwork", "drawSummary", "drawOverview")

    def __init__(self, seed, setContextFunc, population=None, train=True):
        """
        :param seed: seed stored on the context and used for a newly created population
        :param setContextFunc: callback passed on to BaseContext (used to switch contexts)
        :param population: existing Population to use; a new one with 100 networks
                           is created when None
        :param train: True to simulate the whole generation and evolve it,
                      False to only run the fittest network
        """
        BaseContext.__init__(self, setContextFunc)
        self.seed = seed
        self.pop = Population(seed, 100) if population is None else population
        if train:
            # fittest network first, so that self.worlds[0] belongs to it
            ranked = sorted(self.pop.current_generation, key=lambda x: x.fitness, reverse=True)
            self.worlds = [self._createWorld(net) for net in ranked]
        else:
            best_nn = max(self.pop.current_generation, key=lambda x: x.fitness)
            # NOTE(review): unlike the training worlds, no platform is generated
            # for this world - confirm that this is intentional.
            self.worlds = [self._createWorld(best_nn, generatePlatform=False)]
        self.drawmode = 0
        self._train = train

        # the gui overlays
        self._overlays = texhandler.adjustedSurface(texhandler.Textures.overlays, height=48)

        # fonts are created once here instead of on every drawn frame
        self._generationFont = SysFont("Monospace", 40, bold=True)
        self._summaryFont = SysFont("Monospace", 18, bold=True)
        self._timeFont = SysFont("Monospace", 34, bold=True)

        fontObj = SysFont("Monospace", 30, bold=True)
        # mode switch buttons
        self.addElements({
            "bNext": GuiButton(constants.screenWidth - 100, constants.screenHeight - 70, fontObj, "->",
                               width=70).connect(self.buttonModeSwitch, 1),
            "bPrevious": GuiButton(30, constants.screenHeight - 70, fontObj, "<-", width=70).connect(
                self.buttonModeSwitch, -1)
        })

    def _createWorld(self, net, generatePlatform=True):
        """
        creates a NeuronalWorld with an attached renderer for the given network

        :param net: network the world is created for
        :param generatePlatform: whether generatePlatform() is called on the new world
        :return: the new world
        """
        world = NeuronalWorld(self.pop.seed, net)
        world.renderer = RenderNeuronalWorld(world)
        if generatePlatform:
            world.generatePlatform()
        return world

    def calculateDelta(self, clock):
        """
        returns constants.UPS while training, otherwise the delta calculated by BaseContext
        """
        if self._train:
            return constants.UPS
        return BaseContext.calculateDelta(self, clock)

    def update(self, t):
        """
        updates all worlds and, in training mode, advances to the next generation
        once no world is running any more

        :param t: delta passed on to BaseContext.update; the worlds themselves
                  are always updated with constants.UPS
        """
        BaseContext.update(self, t)

        # every world has to be updated, so the complete list is built before
        # any() looks at it (a generator would stop at the first running world)
        running = [world.update(constants.UPS) for world in self.worlds]
        if any(running) or not self._train:
            return

        self.pop.save_to_file(constants.res_loc("networks") + self.pop.name + ".pop")
        self.pop.create_next_generation()
        self.pop.generation_count += 1

        self.worlds = [self._createWorld(net) for net in self.pop.current_generation]

    def draw(self, screen):
        """
        draws the worlds using the current draw mode, the gui elements of
        BaseContext and the generation counter
        """
        if self.worlds:
            getattr(self, self._DRAW_MODES[self.drawmode])(screen)
        BaseContext.draw(self, screen)
        # draw current generation
        renderedGen = self._generationFont.render(str(self.pop.generation_count), 1, (50, 50, 50))
        # draw generation overlay (dinosaur egg)
        screen.blit(self._overlays, (485, 650), (336, 0, 48, 48))
        screen.blit(renderedGen, (540, 655))

    def drawSimple(self, screen):
        """
        only draws the first world
        """
        self.worlds[0].renderer.render(screen)

    def drawNetwork(self, screen):
        """
        draws the first world together with its network-graph
        """
        # draw the world
        world = self.worlds[0]
        world.renderer.render(screen)

        # transparent surface the network graph is rendered onto
        networkSurface = pygame.Surface((750, 180)).convert_alpha()
        networkSurface.fill((0, 0, 0, 0))
        # draw the minimap and network
        networkrenderer.render_network(networkSurface, world.nn, world.minimapValues)
        screen.blit(networkSurface, (10, 60))

    def drawSummary(self, screen):
        """
        draws a summary: the fitness of every world in columns and the
        largest world time as headline
        """
        x = 30
        y = 50
        fontObj = self._summaryFont

        longestTime = 0
        rowHeight = fontObj.get_height() + 2

        for world in self.worlds:
            longestTime = max(longestTime, world.time)
            renderedFitness = fontObj.render("Fitness: {0:.2f}".format(world.nn.fitness), 1, (255, 255, 255))

            if y + rowHeight > (constants.screenHeight - 95):
                # column is full: continue at the top of the next column
                y = 50
                x += 260
                # stop as soon as an entry would not fit on the screen any more
                if x + renderedFitness.get_width() > constants.screenWidth:
                    break
                # draw vertical line
                pygame.draw.line(screen, (100, 100, 100), (x - 10, 45), (x - 10, constants.screenHeight - 100), 2)
            # draw horizontal line (if still on first column)
            elif x < 100:
                pygame.draw.line(screen, (100, 100, 100), (25, y + rowHeight - 3),
                                 (constants.screenWidth - 25, y + rowHeight - 3), 2)

            screen.blit(renderedFitness, (x, y))

            y += rowHeight

        # draw the time
        renderedTime = self._timeFont.render("Time: {0:.2f}".format(longestTime), 1, (255, 255, 255))
        screen.blit(renderedTime, ((constants.screenWidth - renderedTime.get_width()) // 2, 10))

    def drawOverview(self, screen):
        """
        draws the best 9 worlds in a 3x3 grid; falls back to drawSimple when
        there are fewer than 9 worlds
        """
        if len(self.worlds) < 9:
            self.drawSimple(screen)
            return

        partWidth = constants.screenWidth // 3
        partHeight = constants.screenHeight // 3
        bestWorlds = sorted(self.worlds, key=lambda x: x.nn.fitness, reverse=True)[:9]
        for index, world in enumerate(bestWorlds):
            row, column = divmod(index, 3)
            # render at full size first, then scale down to the size of one tile
            surface = pygame.Surface(constants.screenSize)
            world.renderer.render(surface)
            surface = pygame.transform.scale(surface, (partWidth, partHeight))
            screen.blit(surface, (partWidth * column, partHeight * row))
            pygame.draw.rect(screen, (0, 0, 0), (partWidth * column, partHeight * row, partWidth, partHeight), 1)

    def handleEvent(self, event):
        """
        handles key presses: ESC opens the pause context, TAB switches to the
        next draw mode

        :return: True if BaseContext already handled the event, otherwise False
        """
        if BaseContext.handleEvent(self, event):
            return True

        if event.type == pygame.KEYDOWN:
            if event.key == pygame.K_ESCAPE:
                # imported here rather than at module level
                # NOTE(review): presumably to avoid a circular import - confirm
                from context.gamepausecontext import GamePauseContext
                self._setContextFunc(GamePauseContext(self, self._setContextFunc))
            elif event.key == pygame.K_TAB:
                self.drawmode = (self.drawmode + 1) % len(self._DRAW_MODES)
        return False

    # button functions

    def buttonModeSwitch(self, inc):
        """
        changes the draw mode by inc, wrapping around at both ends
        """
        self.drawmode = (self.drawmode + inc) % len(self._DRAW_MODES)