Example #1
    def _oneInteraction(self):

        resetInThisRound = False

        old = (self.XA, self.switch_state)
        (self.XA, self.switch_state) = reverseStateMapper[level.state]
        payoff = stateToRewardMapper[level.state]

        self.acc_reward += payoff * 10
        if self.collect_data:
            self.count += 1
            if payoff > 0:
                self.collect_episode_data_file.write(str(self.count) + "\n")
                self.count = 0
            if self.stepid % interval == 0:
                self.collect_reward_data_file.write(
                    str(self.acc_reward / float(interval)) + "\n")
                self.acc_reward = 0
            if self.stepid % 100000 == 0:
                pass

        if self.stepid % interval == 0:
            sys.stdout.write("\033[K")
            sys.stdout.write(
                "[{2}{3}] ({0}/{1}) | alpha = {4} | epsilon = {5}\n".format(
                    self.stepid, MAX_STEPS,
                    '#' * int(math.floor(self.stepid / float(MAX_STEPS) * 20)),
                    ' ' * int(
                        (20 -
                         math.floor(self.stepid / float(MAX_STEPS) * 20))),
                    learner.alpha, learner.explorer.exploration))
            sys.stdout.write("\033[F")

        if self.stepid >= MAX_STEPS:
            print("\nSimulation done!")

            sys.exit()

        if payoff > 0:
            # episode done
            if save_file is not None:
                controller.params.reshape(
                    controller.numRows,
                    controller.numColumns).tofile(save_file)
            learner.alpha *= 0.999999
            learner.explorer.exploration *= 0.999999
        if level.state == errorState:
            level.reset()

        self.isCrashed = False
        if not self.isPaused:
            return Experiment._oneInteraction(self)
        else:
            return self.stepid
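
The data-collection branch above accumulates the scaled payoff and writes its running average to a file every interval steps. A minimal standalone sketch of that bookkeeping, with an illustrative class name and output path that are not part of the original project:

# Sketch of the windowed reward logging used in _oneInteraction (names are illustrative).
class RewardLogger(object):
    def __init__(self, interval, path="reward.log"):
        self.interval = interval      # steps per averaging window
        self.acc_reward = 0.0         # reward accumulated in the current window
        self.step = 0
        self.out = open(path, "w")

    def record(self, payoff):
        self.step += 1
        self.acc_reward += payoff * 10            # same scaling as the snippet above
        if self.step % self.interval == 0:
            # write the window's mean reward, then start a new window
            self.out.write(str(self.acc_reward / float(self.interval)) + "\n")
            self.acc_reward = 0.0
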
class BoxSearchRunner():

  def __init__(self, mode):
    self.mode = mode
    cu.mem('Reinforcement Learning Started')
    self.environment = BoxSearchEnvironment(config.get(mode+'Database'), mode, config.get(mode+'GroundTruth'))
    self.controller = QNetwork()
    cu.mem('QNetwork controller created')
    self.learner = None
    self.agent = BoxSearchAgent(self.controller, self.learner)
    self.task = BoxSearchTask(self.environment, config.get(mode+'GroundTruth'))
    self.experiment = Experiment(self.task, self.agent)

  def runEpoch(self, interactions, maxImgs):
    img = 0
    s = cu.tic()
    while img < maxImgs:
      k = 0
      while not self.environment.episodeDone and k < interactions:
        self.experiment._oneInteraction()
        k += 1
      self.agent.learn()
      self.agent.reset()
      self.environment.loadNextEpisode()
      img += 1
    s = cu.toc('Run epoch with ' + str(maxImgs) + ' episodes', s)

  def run(self):
    if self.mode == 'train':
      self.agent.persistMemory = True
      self.agent.startReplayMemory(len(self.environment.imageList), config.geti('trainInteractions'))
      self.train()
    elif self.mode == 'test':
      self.agent.persistMemory = False
      self.test()

  def train(self):
    networkFile = config.get('networkDir') + config.get('snapshotPrefix') + '_iter_' + config.get('trainingIterationsPerBatch') + '.caffemodel'
    interactions = config.geti('trainInteractions')
    minEpsilon = config.getf('minTrainingEpsilon')
    epochSize = len(self.environment.imageList)/1
    epsilon = 1.0
    self.controller.setEpsilonGreedy(epsilon, self.environment.sampleAction)
    epoch = 1
    exEpochs = config.geti('explorationEpochs')
    while epoch <= exEpochs:
      s = cu.tic()
      print 'Epoch',epoch,': Exploration (epsilon=1.0)'
      self.runEpoch(interactions, len(self.environment.imageList))
      self.task.flushStats()
      self.doValidation(epoch)
      s = cu.toc('Epoch done in ',s)
      epoch += 1
    self.learner = QLearning()
    self.agent.learner = self.learner
    egEpochs = config.geti('epsilonGreedyEpochs')
    while epoch <= egEpochs + exEpochs:
      s = cu.tic()
      epsilon = epsilon - (1.0-minEpsilon)/float(egEpochs)
      if epsilon < minEpsilon: epsilon = minEpsilon
      self.controller.setEpsilonGreedy(epsilon, self.environment.sampleAction)
      print 'Epoch',epoch,'(epsilon-greedy:{:5.3f})'.format(epsilon)
      self.runEpoch(interactions, epochSize)
      self.task.flushStats()
      self.doValidation(epoch)
      s = cu.toc('Epoch done in ',s)
      epoch += 1
    maxEpochs = config.geti('exploitLearningEpochs') + exEpochs + egEpochs
    while epoch <= maxEpochs:
      s = cu.tic()
      print 'Epoch',epoch,'(exploitation mode: epsilon={:5.3f})'.format(epsilon)
      self.runEpoch(interactions, epochSize)
      self.task.flushStats()
      self.doValidation(epoch)
      s = cu.toc('Epoch done in ',s)
      shutil.copy(networkFile, networkFile + '.' + str(epoch))
      epoch += 1

  def test(self):
    interactions = config.geti('testInteractions')
    self.controller.setEpsilonGreedy(config.getf('testEpsilon'))
    self.runEpoch(interactions, len(self.environment.imageList))

  def doValidation(self, epoch):
    if epoch % config.geti('validationEpochs') != 0:
      return
    auxRL = BoxSearchRunner('test')
    auxRL.run()
    indexType = config.get('evaluationIndexType')
    category = config.get('category')
    if indexType == 'pascal':
      categories, catIndex = bse.get20Categories()
    elif indexType == 'relations':
      categories, catIndex = bse.getCategories()
    elif indexType == 'finetunedRelations':
      categories, catIndex = bse.getRelationCategories()
    if category in categories:
      catI = categories.index(category)
    else:
      catI = -1
    scoredDetections = bse.loadScores(config.get('testMemory'), catI)
    groundTruthFile = config.get('testGroundTruth')
    #ps,rs = bse.evaluateCategory(scoredDetections, 'scores', groundTruthFile)
    pl,rl = bse.evaluateCategory(scoredDetections, 'landmarks', groundTruthFile)
    line = lambda x,y,z: x + '\t{:5.3f}\t{:5.3f}\n'.format(y,z)
    #print line('Validation Scores:',ps,rs)
    print line('Validation Landmarks:',pl,rl)
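
A hypothetical driver for BoxSearchRunner, assuming the config module has already been initialized elsewhere in the project; the command-line handling below is an illustration, not part of the original source:

import sys

if __name__ == '__main__':
    # 'train' runs the three-phase training loop, 'test' runs a single evaluation epoch.
    mode = sys.argv[1] if len(sys.argv) > 1 else 'test'
    BoxSearchRunner(mode).run()
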
Example #3
    def _oneInteraction(self):
        global draw

        resetInThisRound = False

        # Process events
        for event in pygame.event.get():
            if event.type == pygame.locals.QUIT or (
                    event.type == pygame.locals.KEYDOWN and event.key
                    in [pygame.locals.K_ESCAPE, pygame.locals.K_q]):
                return
            if (event.type == pygame.locals.KEYDOWN
                    and event.key == pygame.locals.K_SPACE):
                print len(controller.params)
                print controller.params.reshape(controller.numRows,
                                                controller.numColumns)
                controller.params.reshape(
                    controller.numRows,
                    controller.numColumns).tofile("test.table")
                self.isPaused = not self.isPaused
            if (event.type == pygame.locals.KEYDOWN
                    and event.key == pygame.locals.K_r):
                resetInThisRound = True
            if (event.type == pygame.locals.KEYDOWN
                    and event.key == pygame.locals.K_PLUS):
                self.speed += 1
            if (event.type == pygame.locals.KEYDOWN
                    and event.key == pygame.locals.K_MINUS):
                self.speed = max(self.speed - 1, 1)
            if (event.type == pygame.locals.KEYDOWN
                    and event.key == pygame.locals.K_d):
                draw = not draw

        # if self.isCrashed:
        #     self.isCrashed = False
        #     # level.reset()

        # Update
        if resetInThisRound:
            print "reset"
            level.reset()

        old = (self.robotXA, self.robotYA)
        (self.robotXA, self.robotYA, csf,
         payoff) = reverseStateMapper[level.state]

        if not self.isCrashed and enemies_enabled:
            enemy_handler.update(old)
            for e in enemy_handler.getEnemyPositions():
                if (self.robotXA, self.robotYA) == e:
                    self.isCrashed = True
                    level.penalty += 1
                    self.acc_reward -= 1
                    if shield_options > 0 and not args.huge_neg_reward:
                        print "Shields are not allowed to make errors!"
                        exit()
                    break

        if (self.robotXA + 1, self.robotYA + 1) in bombs:
            self.bomb_counter += 1
            if self.bomb_counter == 4:
                self.isCrashed = True
                level.penalty += 1
                self.acc_reward -= 1
                if shield_options > 0 and not args.huge_neg_reward:
                    print "Shields are not allowed to make errors!"
                    exit()
        else:
            self.bomb_counter = 0

        if draw:
            q_max = 0
            for state in range(len(reverseStateMapper) - 1):
                q_max = max(q_max, max(controller.getActionValues(state)))

            # Draw Field
            for x in xrange(0, xsize):
                for y in xrange(0, ysize):
                    paletteColor = imageData[y * xsize + x]
                    color = palette[paletteColor * 3:paletteColor * 3 + 3]
                    pygame.draw.rect(self.screenBuffer, color,
                                     ((x + 1) * MAGNIFY,
                                      (y + 1) * MAGNIFY, MAGNIFY, MAGNIFY), 0)

            # Draw boundary
            if self.robotXA == -1 or self.isCrashed:
                boundaryColor = (255, 0, 0)
            else:
                boundaryColor = (64, 64, 64)
            pygame.draw.rect(self.screenBuffer, boundaryColor,
                             (0, 0, MAGNIFY * (xsize + 2), MAGNIFY), 0)
            pygame.draw.rect(self.screenBuffer, boundaryColor,
                             (0, MAGNIFY, MAGNIFY, MAGNIFY * (ysize + 1)), 0)
            pygame.draw.rect(self.screenBuffer, boundaryColor,
                             (MAGNIFY * (xsize + 1), MAGNIFY, MAGNIFY,
                              MAGNIFY * (ysize + 1)), 0)
            pygame.draw.rect(self.screenBuffer, boundaryColor,
                             (MAGNIFY, MAGNIFY *
                              (ysize + 1), MAGNIFY * xsize, MAGNIFY), 0)
            # pygame.draw.rect(screenBuffer,boundaryColor,(0,0,MAGNIFY*(xsize+2),MAGNIFY),0)

            # Draw cell frames
            for x in xrange(0, xsize):
                for y in xrange(0, ysize):
                    pygame.draw.rect(self.screenBuffer, (0, 0, 0),
                                     ((x + 1) * MAGNIFY,
                                      (y + 1) * MAGNIFY, MAGNIFY, MAGNIFY), 1)
                    if (x + 1, y + 1) in bombs:
                        self.screenBuffer.blit(self.bombImage,
                                               ((x + 1) * MAGNIFY + 1,
                                                (y + 1) * MAGNIFY + 1))
            pygame.draw.rect(self.screenBuffer, (0, 0, 0),
                             (MAGNIFY - 1, MAGNIFY - 1, MAGNIFY * xsize + 2,
                              MAGNIFY * ysize + 2), 1)

            # Draw "Good" Robot
            if self.robotXA != -1:
                pygame.draw.circle(
                    self.screenBuffer, (192, 32, 32),
                    ((self.robotXA + 1) * MAGNIFY + MAGNIFY / 2,
                     (self.robotYA + 1) * MAGNIFY + MAGNIFY / 2),
                    MAGNIFY / 3 - 2, 0)
                pygame.draw.circle(
                    self.screenBuffer, (255, 255, 255),
                    ((self.robotXA + 1) * MAGNIFY + MAGNIFY / 2,
                     (self.robotYA + 1) * MAGNIFY + MAGNIFY / 2),
                    MAGNIFY / 3 - 1, 1)
                pygame.draw.circle(
                    self.screenBuffer, (0, 0, 0),
                    ((self.robotXA + 1) * MAGNIFY + MAGNIFY / 2,
                     (self.robotYA + 1) * MAGNIFY + MAGNIFY / 2), MAGNIFY / 3,
                    1)

            # Draw "Bad" Robots
            if enemies_enabled:
                for (e_x, e_y) in enemy_handler.getEnemyPositions():
                    pygame.draw.circle(self.screenBuffer, (32, 32, 192),
                                       ((e_x + 1) * MAGNIFY + MAGNIFY / 2,
                                        (e_y + 1) * MAGNIFY + MAGNIFY / 2),
                                       MAGNIFY / 3 - 2, 0)
                    pygame.draw.circle(self.screenBuffer, (255, 255, 255),
                                       ((e_x + 1) * MAGNIFY + MAGNIFY / 2,
                                        (e_y + 1) * MAGNIFY + MAGNIFY / 2),
                                       MAGNIFY / 3 - 1, 1)
                    pygame.draw.circle(self.screenBuffer, (0, 0, 0),
                                       ((e_x + 1) * MAGNIFY + MAGNIFY / 2,
                                        (e_y + 1) * MAGNIFY + MAGNIFY / 2),
                                       MAGNIFY / 3, 1)

            # zone_width = danger_zone[-1][0] - danger_zone[0][0] + 1
            # zone_height = danger_zone[-1][1] - danger_zone[0][1] + 1
            # pygame.draw.rect(screenBuffer, (200, 200, 0),
            #                  (MAGNIFY * (danger_zone[0][0] + 1),
            #                   MAGNIFY * (danger_zone[0][1] + 1),
            #                   MAGNIFY * zone_width, MAGNIFY * zone_height), 5)

            # Flip!
            self.screen.blit(self.screenBuffer, (0, 0))
            pygame.display.flip()

            # Make the transition
            if not self.isPaused:
                # Done
                self.clock.tick(self.speed)
            else:
                self.clock.tick(3)

        self.acc_reward += payoff * 10
        if self.collect_data:
            self.count += 1
            if payoff > 0:
                self.collect_episode_data_file.write(str(self.count) + "\n")
                self.count = 0
            if self.stepid % 100 == 0:
                self.collect_reward_data_file.write(
                    str(self.acc_reward / 100.) + "\n")
                self.acc_reward = 0
            if self.stepid % 100000 == 0:
                pass
                # print learner.alpha
                # print learner.explorer.exploration
                # print self.stepid
                # raw_input()

        if self.stepid % 100 == 0:
            sys.stdout.write("\033[K")
            sys.stdout.write(
                "[{2}{3}] ({0}/{1}) | alpha = {4} | epsilon = {5}\n".format(
                    self.stepid, MAX_STEPS,
                    '#' * int(math.floor(self.stepid / float(MAX_STEPS) * 20)),
                    ' ' * int(
                        (20 -
                         math.floor(self.stepid / float(MAX_STEPS) * 20))),
                    learner.alpha, learner.explorer.exploration))
            sys.stdout.write("\033[F")

        if self.stepid >= MAX_STEPS:
            print "\nSimulation done!"

            sys.exit()

        if payoff > 0:
            # episode done
            if save_file is not None:
                controller.params.reshape(
                    controller.numRows,
                    controller.numColumns).tofile(save_file)
            learner.alpha *= 1.  #0.999
            learner.explorer.exploration *= 1.  #0.999

        self.isCrashed = False
        if not self.isPaused:
            return Experiment._oneInteraction(self)
        else:
            return self.stepid
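
The console status line above assembles a 20-character progress bar from the current step count. Pulled out as a small helper for clarity (a sketch of the same arithmetic; MAX_STEPS, alpha and epsilon become parameters instead of globals):

import math
import sys

def print_progress(stepid, max_steps, alpha, epsilon, width=20):
    # Number of filled cells in the progress bar.
    filled = int(math.floor(stepid / float(max_steps) * width))
    sys.stdout.write("\033[K")        # clear the current console line
    sys.stdout.write("[{0}{1}] ({2}/{3}) | alpha = {4} | epsilon = {5}\n".format(
        '#' * filled, ' ' * (width - filled), stepid, max_steps, alpha, epsilon))
    sys.stdout.write("\033[F")        # move the cursor back up one line
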
Example #4
class BoxSearchRunner():
    def __init__(self, mode):
        self.mode = mode
        cu.mem('Reinforcement Learning Started')
        self.environment = BoxSearchEnvironment(
            config.get(mode + 'Database'), mode,
            config.get(mode + 'GroundTruth'))
        self.controller = QNetwork()
        cu.mem('QNetwork controller created')
        self.learner = None
        self.agent = BoxSearchAgent(self.controller, self.learner)
        self.task = BoxSearchTask(self.environment,
                                  config.get(mode + 'GroundTruth'))
        self.experiment = Experiment(self.task, self.agent)

    def runEpoch(self, interactions, maxImgs):
        img = 0
        s = cu.tic()
        while img < maxImgs:
            k = 0
            while not self.environment.episodeDone and k < interactions:
                self.experiment._oneInteraction()
                k += 1
            self.agent.learn()
            self.agent.reset()
            self.environment.loadNextEpisode()
            img += 1
        s = cu.toc('Run epoch with ' + str(maxImgs) + ' episodes', s)

    def run(self):
        if self.mode == 'train':
            self.agent.persistMemory = True
            self.agent.startReplayMemory(len(self.environment.imageList),
                                         config.geti('trainInteractions'))
            #self.agent.assignPriorMemory(self.environment.priorMemory)
            self.train()
        elif self.mode == 'test':
            self.agent.persistMemory = False
            self.test()

    def train(self):
        networkFile = config.get('networkDir') + config.get(
            'snapshotPrefix') + '_iter_' + config.get(
                'trainingIterationsPerBatch') + '.caffemodel'
        interactions = config.geti('trainInteractions')
        minEpsilon = config.getf('minTrainingEpsilon')
        epochSize = len(self.environment.imageList) / 1
        epsilon = 1.0
        self.controller.setEpsilonGreedy(epsilon,
                                         self.environment.sampleAction)
        epoch = 1
        exEpochs = config.geti('explorationEpochs')
        while epoch <= exEpochs:
            s = cu.tic()
            print 'Epoch', epoch, ': Exploration (epsilon=1.0)'
            self.runEpoch(interactions, len(self.environment.imageList))
            self.task.flushStats()
            s = cu.toc('Epoch done in ', s)
            epoch += 1
        self.learner = QLearning()
        self.agent.learner = self.learner
        egEpochs = config.geti('epsilonGreedyEpochs')
        while epoch <= egEpochs + exEpochs:
            s = cu.tic()
            epsilon = epsilon - (1.0 - minEpsilon) / float(egEpochs)
            if epsilon < minEpsilon: epsilon = minEpsilon
            self.controller.setEpsilonGreedy(epsilon,
                                             self.environment.sampleAction)
            print 'Epoch', epoch, '(epsilon-greedy:{:5.3f})'.format(epsilon)
            self.runEpoch(interactions, epochSize)
            self.task.flushStats()
            self.doValidation(epoch)
            s = cu.toc('Epoch done in ', s)
            epoch += 1
        maxEpochs = config.geti('exploitLearningEpochs') + exEpochs + egEpochs
        while epoch <= maxEpochs:
            s = cu.tic()
            print 'Epoch', epoch, '(exploitation mode: epsilon={:5.3f})'.format(
                epsilon)
            self.runEpoch(interactions, epochSize)
            self.task.flushStats()
            self.doValidation(epoch)
            s = cu.toc('Epoch done in ', s)
            shutil.copy(networkFile, networkFile + '.' + str(epoch))
            epoch += 1

    def test(self):
        interactions = config.geti('testInteractions')
        self.controller.setEpsilonGreedy(config.getf('testEpsilon'))
        self.runEpoch(interactions, len(self.environment.imageList))

    def doValidation(self, epoch):
        if epoch % config.geti('validationEpochs') != 0:
            return
        auxRL = BoxSearchRunner('test')
        auxRL.run()
        indexType = config.get('evaluationIndexType')
        category = config.get('category')
        if indexType == 'pascal':
            categories, catIndex = bse.get20Categories()
        elif indexType == 'relations':
            categories, catIndex = bse.getCategories()
        elif indexType == 'finetunedRelations':
            categories, catIndex = bse.getRelationCategories()
        catI = categories.index(category)
        scoredDetections = bse.loadScores(config.get('testMemory'), catI)
        groundTruthFile = config.get('testGroundTruth')
        ps, rs = bse.evaluateCategory(scoredDetections, 'scores',
                                      groundTruthFile)
        pl, rl = bse.evaluateCategory(scoredDetections, 'landmarks',
                                      groundTruthFile)
        line = lambda x, y, z: x + '\t{:5.3f}\t{:5.3f}\n'.format(y, z)
        print line('Validation Scores:', ps, rs)
        print line('Validation Landmarks:', pl, rl)
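
The train() method above runs three phases: pure exploration at epsilon = 1.0 for explorationEpochs, a linear decay toward minTrainingEpsilon over epsilonGreedyEpochs, and constant-epsilon exploitation for the remaining epochs. The schedule can be summarized as a small function (a sketch of the same arithmetic; the function name is illustrative):

def epsilon_for_epoch(epoch, ex_epochs, eg_epochs, min_epsilon):
    # Epsilon used during the given 1-based epoch, mirroring the loops in train().
    if epoch <= ex_epochs:
        return 1.0                                        # exploration phase
    decay_steps = epoch - ex_epochs                       # epochs spent decaying so far
    eps = 1.0 - decay_steps * (1.0 - min_epsilon) / float(eg_epochs)
    return max(eps, min_epsilon)                          # clamped during exploitation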