def __init__(self, ctx, generator, driver):
    '''Initializes a game in a generated world that driver will play.

    Args:
      ctx: the context the driver interacts through.
      generator: produces the world for the underlying simulation.
      driver: the agent that will interact with the game.
    '''
    self._driver = driver
    self._context = ctx
    self._sim = simulation.Simulation(generator)
    # Running win/loss tallies across games.
    self._wins = 0
    self._losses = 0
    self._was_in_terminal_state = False
def testTrain(self):
    '''train accepts batched episodes of (state, action, advantage) triples.

    Smoke test: a single training step on two episodes must run without
    raising.
    '''
    g = tf.Graph()
    net = pg.PolicyGradientNetwork('testTrain', g, (4, 4))
    # Use the session as a context manager so it is closed instead of
    # leaked when the test finishes (the original never closed it).
    with tf.Session(graph=g) as s:
        with g.as_default():
            init = tf.global_variables_initializer()
            s.run(init)
        sim = simulation.Simulation(world.Generator(4, 4))
        state = sim.to_array()
        net.train(s, [[(state, 3, 7), (state, 3, -1)], [(state, 0, 1000)]])
def test_interact(self):
    '''A greedy player forwards its single best action to the simulation.'''
    best_action = movement.ACTION_RIGHT
    # Q-table defaults to -1 everywhere; make one action stand out.
    q = grid.QTable(-1)
    q.set((0, 0), best_action, 1)
    player = grid.MachinePlayer(grid.GreedyQ(q), grid.StubLearner())
    w = world.World.parse('@.')
    with patch.object(simulation.Simulation, 'act') as mock_act:
        sim = simulation.Simulation(world.Static(w))
        player.interact(context.StubContext(), sim)
    mock_act.assert_called_once_with(best_action)
def testPredict(self):
    '''predict returns a valid action index for a single input state.'''
    g = tf.Graph()
    net = pg.PolicyGradientNetwork('testPredict', g, (7, 11))
    # Context-manage the session so it is closed, not leaked.
    with tf.Session(graph=g) as s:
        with g.as_default():
            init = tf.global_variables_initializer()
            s.run(init)
        sim = simulation.Simulation(world.Generator(11, 7))
        state = sim.to_array()
        [[act], _] = net.predict(s, [state])
        # Range-specific asserts report the offending value on failure,
        # unlike assertTrue(0 <= act).
        self.assertGreaterEqual(act, 0)
        self.assertLess(act, len(movement.ALL_ACTIONS))
def testUpdate_lossDecreases(self):
    '''Repeated updates on one fixed example drive the loss down.'''
    w = world.World.parse('@.....$')
    g = tf.Graph()
    net = pg.PolicyGradientNetwork('testUpdate_lossDecreases', g, (w.h, w.w))
    # Context-manage the session so it is closed, not leaked.
    with tf.Session(graph=g) as s:
        with g.as_default():
            s.run(tf.global_variables_initializer())
        state = simulation.Simulation(world.Static(w)).to_array()
        # The feed dict is loop-invariant; build it once.
        feed = {
            net.state: [state],
            net.action_in: [[1]],
            net.advantage: [[2]],
        }
        losses = []
        for _ in range(10):
            loss, _ = s.run([net.loss, net.update], feed_dict=feed)
            losses.append(loss)
    # assertLess reports both values on failure, unlike assertTrue(a < b).
    self.assertLess(losses[-1], losses[0])
def test_to_array(self):
    '''Each world tile maps to its integer code in the array encoding.'''
    w = world.World.parse('$.@^#')
    sim = simulation.Simulation(world.Static(w))
    # assert_array_equal pinpoints mismatching elements on failure,
    # unlike assertTrue((a == b).all()), which only reports False.
    np.testing.assert_array_equal(
        np.array([[2, 3, 4, 5, 1]], dtype=np.int8), sim.to_array())
def test_act_accumulates_score(self):
    '''Each step's reward is accumulated into the simulation score.'''
    sim = simulation.Simulation(world.Static(world.World.parse('@.')))
    # Two moves, each costing 1, should leave the score at -2.
    for step in (movement.ACTION_RIGHT, movement.ACTION_LEFT):
        sim.act(step)
    self.assertEqual(-2, sim.score)
def test_in_terminal_state(self):
    '''Stepping onto the trap tile puts the simulation in a terminal state.'''
    sim = simulation.Simulation(world.Static(world.World.parse('@^')))
    # Fresh simulation starts out non-terminal.
    self.assertFalse(sim.in_terminal_state)
    sim.act(movement.ACTION_RIGHT)
    self.assertTrue(sim.in_terminal_state)