コード例 #1
0
ファイル: grid.py プロジェクト: mikechen66/SimpleRL
def main():
    """Parse command-line flags, build a map generator and player, run the game.

    Exactly one of --interactive, --q or --pg must be supplied: the three
    flags live in a required, mutually exclusive argparse group. Passing
    --random swaps the built-in static map for ``world.Generator(25, 15)``.
    In the automatic modes (--q / --pg) an extra repeating task sleeps 0.1 s
    per iteration so the play can be followed by eye.
    """
    parser = argparse.ArgumentParser(
        description='Simple Reinforcement Learning.')
    modes = parser.add_mutually_exclusive_group(required=True)
    modes.add_argument('--interactive', action='store_true',
                       help='use the keyboard arrow keys to play')
    modes.add_argument('--q', action='store_true',
                       help='play automatically with Q-learning')
    modes.add_argument('--pg', action='store_true',
                       help='play automatically with policy gradients')
    parser.add_argument('--random', action='store_true',
                        help='generate a random map')
    options = parser.parse_args()

    ctx = context.Context()

    # Choose where levels come from: a random generator, or one fixed map.
    if options.random:
        generator = world.Generator(25, 15)
    else:
        layout = world.World.parse('''\
  ########
  #..#...#
  #.@#.$.#
  #.##^^.#
  #......#
  ########
  ''')
        generator = world.Static(layout)

    # Choose who plays: a human at the keyboard, or one of the two learners.
    if options.interactive:
        player = HumanPlayer()
    elif options.q:
        table = QTable()
        learner = QLearner(table, 0.05, 0.1)
        explore = EpsilonPolicy(GreedyQ(table), RandomPolicy(), 0.01)
        player = MachinePlayer(explore, learner)
    elif options.pg:
        graph = tf.Graph()
        session = tf.Session(graph=graph)
        player = policy_gradient.PolicyGradientPlayer(
            graph, session, generator.size)
        # Initialise only after the player is built; presumably the player
        # adds its variables to `graph` -- confirm in policy_gradient.
        with graph.as_default():
            session.run(tf.global_variables_initializer())
    else:
        # Not reachable in practice: argparse rejects a command line that
        # sets none of the required, mutually exclusive mode flags.
        sys.exit(1)

    if options.q or options.pg:
        # Slow automatic play down so it is fun to watch.
        ctx.run_loop.post_task(lambda: time.sleep(0.1), repeat=True)

    ctx.run_loop.post_task(Game(ctx, generator, player).step, repeat=True)

    ctx.start()
コード例 #2
0
    def test_interact(self):
        """MachinePlayer.interact should perform the greedy Q-table action."""
        expected_action = movement.ACTION_RIGHT
        # QTable(-1): presumably -1 is the value for unset entries (confirm in
        # grid.QTable); only (0, 0) / expected_action is raised to 1, so the
        # greedy policy should prefer it.
        table = grid.QTable(-1)
        table.set((0, 0), expected_action, 1)
        player = grid.MachinePlayer(grid.GreedyQ(table), grid.StubLearner())

        layout = world.World.parse('@.')
        # Replace Simulation.act with a mock to capture the chosen action.
        with patch.object(simulation.Simulation, 'act') as act_spy:
            sim = simulation.Simulation(world.Static(layout))
            player.interact(context.StubContext(), sim)
        act_spy.assert_called_once_with(expected_action)
コード例 #3
0
  def testUpdate_lossDecreases(self):
    """Ten update steps on one fixed transition should lower the loss."""
    w = world.World.parse('@.....$')

    g = tf.Graph()
    net = pg.PolicyGradientNetwork('testUpdate_lossDecreases', g, (w.h, w.w))
    # Use the session as a context manager so it is closed (releasing its
    # resources) when the test finishes, even if a step raises.
    with tf.Session(graph=g) as s:
      with g.as_default():
        s.run(tf.global_variables_initializer())

      state = simulation.Simulation(world.Static(w)).to_array()
      losses = []
      for _ in range(10):
        # Repeatedly train on the same (state, action=1, advantage=2) sample.
        loss, _ = s.run([net.loss, net.update], feed_dict={
            net.state: [state],
            net.action_in: [[1]],
            net.advantage: [[2]],
        })
        losses.append(loss)
    # assertLess reports both values on failure, unlike assertTrue(a < b).
    self.assertLess(losses[-1], losses[0])
コード例 #4
0
 def test_to_array(self):
   """Simulation.to_array should encode the '$.@^#' map as the expected codes."""
   w = world.World.parse('$.@^#')
   sim = simulation.Simulation(world.Static(w))
   # assert_array_equal checks shape as well as values and lists the
   # mismatching elements on failure. assertTrue((a == b).all()) can pass
   # through broadcasting on a shape mismatch and only says "False is not true".
   np.testing.assert_array_equal(
     sim.to_array(),
     np.array([[2, 3, 4, 5, 1]], dtype=np.int8))
コード例 #5
0
 def test_act_accumulates_score(self):
   """Stepping right and then back left on '@.' should leave a score of -2."""
   sim = simulation.Simulation(world.Static(world.World.parse('@.')))
   for action in (movement.ACTION_RIGHT, movement.ACTION_LEFT):
     sim.act(action)
   self.assertEqual(-2, sim.score)
コード例 #6
0
 def test_in_terminal_state(self):
   """One step right on the map '@^' should put the simulation in a terminal state."""
   sim = simulation.Simulation(world.Static(world.World.parse('@^')))
   # Before any action the episode must still be running.
   self.assertFalse(sim.in_terminal_state)
   sim.act(movement.ACTION_RIGHT)
   self.assertTrue(sim.in_terminal_state)