    def testEscapeToCornersGen(self):
        # In this test we show how to wrap a native Python model so that
        # MCTS can call it directly to sample transitions, rather than
        # having to define explicit transition and reward functions. Any
        # native Python object can be used, as long as it belongs to a
        # class that provides the following methods:
        #
        # class MyModel:
        #     def getS(self): pass           # Returns S
        #     def getA(self): pass           # Returns A
        #     def getDiscount(self): pass    # Returns the discount
        #     def isTerminal(self, s): pass  # Returns whether the input state is terminal
        #     def sampleSR(self, s, a): pass # Samples a new state-reward *tuple* from the input
        #
        # In our case we use the MDP.Model itself as if it were the Python
        # object to wrap, but you can use this mechanism to wrap any library
        # you want. It would look something like:
        #
        # mymodel = MyModel()
        # mm = MDP.GenerativeModelPython(mymodel)
        mm = MDP.GenerativeModelPython(model)
        mcts = MDP.MCTSGenerativeModelPython(mm, 10000, 5)

        self.assertEqual(mcts.sampleAction(1, 10), LEFT)
        self.assertEqual(mcts.sampleAction(2, 10), LEFT)
        a = mcts.sampleAction(3, 10)
        self.assertEqual(a == LEFT or a == DOWN, True)

        self.assertEqual(mcts.sampleAction(4, 10), UP)
        self.assertEqual(mcts.sampleAction(8, 10), UP)
        a = mcts.sampleAction(5, 10)
        self.assertEqual(a == LEFT or a == UP, True)

        self.assertEqual(mcts.sampleAction(7, 10), DOWN)
        self.assertEqual(mcts.sampleAction(11, 10), DOWN)
        a = mcts.sampleAction(10, 10)
        self.assertEqual(a == RIGHT or a == DOWN, True)

        a = mcts.sampleAction(12, 10)
        self.assertEqual(a == RIGHT or a == UP, True)
        self.assertEqual(mcts.sampleAction(13, 10), RIGHT)
        self.assertEqual(mcts.sampleAction(14, 10), RIGHT)
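    # A concrete (purely hypothetical) illustration of the MyModel interface
    # described in testEscapeToCornersGen above: a tiny native Python model
    # and how it could be wrapped. The class and its dynamics are invented
    # for illustration only and are not part of the test suite.
    #
    # class TwoStateChain:
    #     def getS(self): return 2
    #     def getA(self): return 2
    #     def getDiscount(self): return 0.9
    #     def isTerminal(self, s): return False
    #     def sampleSR(self, s, a):
    #         # Action 0 stays in place for no reward; action 1 moves to
    #         # the other state for a reward of 1.
    #         if a == 0:
    #             return (s, 0.0)
    #         return (1 - s, 1.0)
    #
    # wrapped = MDP.GenerativeModelPython(TwoStateChain())
    # mcts = MDP.MCTSGenerativeModelPython(wrapped, 10000, 5)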
    def testCompatibility(self):
        S, A = 4, 3
        exp = MDP.Experience(S, A)

        # generator() is a numeric helper defined elsewhere in the test module.
        visits = []
        rewards = []
        for s in xrange(0, S):
            visits.append([])
            rewards.append([])
            for a in xrange(0, A):
                rewards[s].append(generator())
                visits[s].append([])
                for s1 in xrange(0, S):
                    visits[s][a].append(generator())

        exp.setVisitsTable(visits)
        exp.setRewardMatrix(rewards)

        for s in xrange(0, S):
            for a in xrange(0, A):
                visitsSum = 0
                for s1 in xrange(0, S):
                    self.assertEqual(exp.getVisits(s, a, s1), visits[s][a][s1])
                    visitsSum += visits[s][a][s1]
                self.assertEqual(exp.getVisitsSum(s, a), visitsSum)
                self.assertEqual(exp.getReward(s, a), rewards[s][a])
    def testEscapeToCorners(self):
        # MCTSModel(model, iterations, exploration constant);
        # sampleAction(state, horizon).
        mcts = MDP.MCTSModel(model, 10000, 5)

        self.assertEqual(mcts.sampleAction(1, 10), LEFT)
        self.assertEqual(mcts.sampleAction(2, 10), LEFT)
        a = mcts.sampleAction(3, 10)
        self.assertEqual(a == LEFT or a == DOWN, True)

        self.assertEqual(mcts.sampleAction(4, 10), UP)
        self.assertEqual(mcts.sampleAction(8, 10), UP)
        a = mcts.sampleAction(5, 10)
        self.assertEqual(a == LEFT or a == UP, True)

        self.assertEqual(mcts.sampleAction(7, 10), DOWN)
        self.assertEqual(mcts.sampleAction(11, 10), DOWN)
        a = mcts.sampleAction(10, 10)
        self.assertEqual(a == RIGHT or a == DOWN, True)

        a = mcts.sampleAction(12, 10)
        self.assertEqual(a == RIGHT or a == UP, True)
        self.assertEqual(mcts.sampleAction(13, 10), RIGHT)
        self.assertEqual(mcts.sampleAction(14, 10), RIGHT)
    def testRecording(self):
        S, A = 5, 6
        exp = MDP.Experience(S, A)

        s, s1, a = 3, 4, 5
        rew, negrew, zerorew = 7.4, -4.2, 0.0

        self.assertEqual(exp.getVisits(s, a, s1), 0)

        exp.record(s, a, s1, rew)
        self.assertEqual(exp.getVisits(s, a, s1), 1)
        self.assertEqual(exp.getReward(s, a), rew)

        exp.reset()
        self.assertEqual(exp.getVisits(s, a, s1), 0)
        self.assertEqual(exp.getReward(s, a), 0)

        exp.record(s, a, s1, negrew)
        self.assertEqual(exp.getVisits(s, a, s1), 1)
        self.assertEqual(exp.getReward(s, a), negrew)

        # getReward returns the running average of all recorded rewards.
        exp.record(s, a, s1, zerorew)
        self.assertEqual(exp.getVisits(s, a, s1), 2)
        self.assertEqual(exp.getReward(s, a), negrew / 2.0)

        self.assertEqual(exp.getVisitsSum(s, a), 2)
    def testConstruction(self):
        S, A = 5, 6
        exp = MDP.Experience(S, A)

        self.assertEqual(exp.getS(), S)
        self.assertEqual(exp.getA(), A)
        self.assertEqual(exp.getVisits(0, 0, 0), 0)
        self.assertEqual(exp.getReward(0, 0), 0.0)
        self.assertEqual(exp.getVisits(S - 1, A - 1, S - 1), 0)
        self.assertEqual(exp.getReward(S - 1, A - 1), 0.0)
    def testUpdates(self):
        # QLearning(S, A, discount, learning rate)
        solver = MDP.QLearning(5, 5, 0.9, 0.5)

        # The state transitions to itself, so the update also has to
        # consider the next-step value.
        solver.stepUpdateQ(0, 0, 0, 10)
        self.assertEqual(solver.getQFunction()[0, 0], 5.0)
        solver.stepUpdateQ(0, 0, 0, 10)
        self.assertEqual(solver.getQFunction()[0, 0], 9.75)

        # Here the next state is different and still has zero value, so
        # the improvement is slower.
        solver.stepUpdateQ(3, 0, 4, 10)
        self.assertEqual(solver.getQFunction()[3, 0], 5.0)
        solver.stepUpdateQ(3, 0, 4, 10)
        self.assertEqual(solver.getQFunction()[3, 0], 7.50)

        # Test that the index combinations are right.
        solver.stepUpdateQ(0, 1, 1, 10)
        self.assertEqual(solver.getQFunction()[0, 1], 5.0)
        self.assertEqual(solver.getQFunction()[1, 0], 0.0)
        self.assertEqual(solver.getQFunction()[1, 1], 0.0)
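    # For reference, the expected values above follow from the standard
    # Q-learning update, Q(s, a) += alpha * (r + gamma * max_a' Q(s', a') - Q(s, a)),
    # with alpha = 0.5 and gamma = 0.9 as passed to the constructor:
    #
    #   Q(0, 0): 0 + 0.5 * (10 + 0.9 * 0 - 0) = 5.0
    #            5 + 0.5 * (10 + 0.9 * 5 - 5) = 9.75  (s' == s, so max Q(s') is 5)
    #   Q(3, 0): 0 + 0.5 * (10 + 0.9 * 0 - 0) = 5.0
    #            5 + 0.5 * (10 + 0.9 * 0 - 5) = 7.5   (s' == 4, whose value is still 0)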
    def testEscapeToCorners(self):
        # This model is done manually, I'll copy the makeCornerProblem
        # C++ stuff that auto generates these tables soon enough.
        model = MDP.Model(16, 4)
        t = [
            [[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
             [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
             [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
             [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
            [[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
             [0, 0.2, 0.8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
             [0, 0.2, 0, 0, 0, 0.8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
             [0.8, 0.2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
            [[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
             [0, 0, 0.2, 0.8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
             [0, 0, 0.2, 0, 0, 0, 0.8, 0, 0, 0, 0, 0, 0, 0, 0, 0],
             [0, 0.8, 0.2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
            [[0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
             [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
             [0, 0, 0, 0.2, 0, 0, 0, 0.8, 0, 0, 0, 0, 0, 0, 0, 0],
             [0, 0, 0.8, 0.2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
            [[0.8, 0, 0, 0, 0.2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0.2, 0.8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0.2, 0, 0, 0, 0.8, 0, 0, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
            [[0, 0.8, 0, 0, 0, 0.2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 0.2, 0.8, 0, 0, 0, 0, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 0.2, 0, 0, 0, 0.8, 0, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0.8, 0.2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
            [[0, 0, 0.8, 0, 0, 0, 0.2, 0, 0, 0, 0, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 0, 0.2, 0.8, 0, 0, 0, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 0, 0.2, 0, 0, 0, 0.8, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 0.8, 0.2, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
            [[0, 0, 0, 0.8, 0, 0, 0, 0.2, 0, 0, 0, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 0, 0, 0.2, 0, 0, 0, 0.8, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 0, 0.8, 0.2, 0, 0, 0, 0, 0, 0, 0, 0]],
            [[0, 0, 0, 0, 0.8, 0, 0, 0, 0.2, 0, 0, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 0, 0, 0, 0.2, 0.8, 0, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 0, 0, 0, 0.2, 0, 0, 0, 0.8, 0, 0, 0],
             [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0]],
            [[0, 0, 0, 0, 0, 0.8, 0, 0, 0, 0.2, 0, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0.2, 0.8, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0.2, 0, 0, 0, 0.8, 0, 0],
             [0, 0, 0, 0, 0, 0, 0, 0, 0.8, 0.2, 0, 0, 0, 0, 0, 0]],
            [[0, 0, 0, 0, 0, 0, 0.8, 0, 0, 0, 0.2, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.2, 0.8, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.2, 0, 0, 0, 0.8, 0],
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0.8, 0.2, 0, 0, 0, 0, 0]],
            [[0, 0, 0, 0, 0, 0, 0, 0.8, 0, 0, 0, 0.2, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.2, 0, 0, 0, 0.8],
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.8, 0.2, 0, 0, 0, 0]],
            [[0, 0, 0, 0, 0, 0, 0, 0, 0.8, 0, 0, 0, 0.2, 0, 0, 0],
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.2, 0.8, 0, 0],
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0]],
            [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0.8, 0, 0, 0, 0.2, 0, 0],
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.2, 0.8, 0],
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0],
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.8, 0.2, 0, 0]],
            [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.8, 0, 0, 0, 0.2, 0],
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.2, 0.8],
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.8, 0.2, 0]],
            [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]],
        ]
        r = [
            [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
            [[0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
             [0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
             [-1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
            [[0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
             [0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
             [0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
            [[0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
             [0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0],
             [0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
            [[-1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
            [[0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
            [[0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
            [[0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
            [[0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0, 0, 0],
             [0, 0, 0, 0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0]],
            [[0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0, 0],
             [0, 0, 0, 0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0]],
            [[0, 0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0],
             [0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0]],
            [[0, 0, 0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1],
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0]],
            [[0, 0, 0, 0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0, 0],
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0, 0, 0],
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0, 0, 0]],
            [[0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0],
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0, 0],
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0, 0, 0]],
            [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1],
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0],
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0, 0]],
            [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
        ]
        model.setTransitionFunction(t)
        model.setRewardFunction(r)

        vi = MDP.ValueIteration(1000000, 0.001)
        bound, vfun, qfun = vi(model)
        self.assertEqual(bound < vi.getTolerance(), True)

        p = MDP.QGreedyPolicy(qfun)

        for a in xrange(0, 4):
            self.assertEqual(p.getActionProbability(0, a), 0.25)
            self.assertEqual(p.getActionProbability(6, a), 0.25)
            self.assertEqual(p.getActionProbability(9, a), 0.25)
            self.assertEqual(p.getActionProbability(15, a), 0.25)

        self.assertEqual(p.getActionProbability(1, LEFT), 1.0)
        self.assertEqual(p.getActionProbability(2, LEFT), 1.0)
        self.assertEqual(p.getActionProbability(3, LEFT), 0.5)
        self.assertEqual(p.getActionProbability(3, DOWN), 0.5)

        self.assertEqual(p.getActionProbability(4, UP), 1.0)
        self.assertEqual(p.getActionProbability(8, UP), 1.0)
        self.assertEqual(p.getActionProbability(5, LEFT), 0.5)
        self.assertEqual(p.getActionProbability(5, UP), 0.5)

        self.assertEqual(p.getActionProbability(7, DOWN), 1.0)
        self.assertEqual(p.getActionProbability(11, DOWN), 1.0)
        self.assertEqual(p.getActionProbability(10, RIGHT), 0.5)
        self.assertEqual(p.getActionProbability(10, DOWN), 0.5)

        self.assertEqual(p.getActionProbability(12, RIGHT), 0.5)
        self.assertEqual(p.getActionProbability(12, UP), 0.5)
        self.assertEqual(p.getActionProbability(13, RIGHT), 1.0)
        self.assertEqual(p.getActionProbability(14, RIGHT), 1.0)

        values = vfun.values
        actions = vfun.actions
        for s in xrange(0, 16):
            self.assertEqual(qfun[s, actions[s]], values[s])
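    # The tables above describe a 4x4 gridworld ("escape to corners"): states
    # are numbered row by row (state = y * 4 + x), the two corners 0 and 15
    # are absorbing, every feasible move succeeds with probability 0.8 and
    # leaves the agent in place with probability 0.2, moves off the grid keep
    # the agent in place, and the -1 step cost is attached to the intended
    # destination (the 0.2 "slip" outcome carries no reward). A sketch of how
    # such tables could be generated automatically follows; it is inferred
    # from the data above and is only an illustration, not the library's
    # makeCornerProblem:
    #
    # def makeCornerTables(side=4):
    #     S = side * side
    #     step = {UP: (0, -1), RIGHT: (1, 0), DOWN: (0, 1), LEFT: (-1, 0)}
    #     t = [[[0.0] * S for _ in range(4)] for _ in range(S)]
    #     r = [[[0.0] * S for _ in range(4)] for _ in range(S)]
    #     for s in range(S):
    #         x, y = s % side, s // side
    #         for a in range(4):
    #             if s == 0 or s == S - 1:
    #                 t[s][a][s] = 1.0           # absorbing corners
    #                 continue
    #             dx, dy = step[a]
    #             nx, ny = x + dx, y + dy
    #             if 0 <= nx < side and 0 <= ny < side:
    #                 s1 = ny * side + nx
    #                 t[s][a][s1] = 0.8          # intended move
    #                 t[s][a][s] = 0.2           # slip: stay in place
    #                 r[s][a][s1] = -1.0
    #             else:
    #                 t[s][a][s] = 1.0           # bumping into the wall
    #                 r[s][a][s] = -1.0
    #     return t, r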
import unittest
import sys
import os

sys.path.append(os.getcwd())

from AIToolbox import MDP

UP = 0
RIGHT = 1
DOWN = 2
LEFT = 3

# This model is done manually, I'll copy the makeCornerProblem
# C++ stuff that auto generates these tables soon enough.
model = MDP.Model(16, 4)
t = [
    [[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
     [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
     [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
     [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
    [[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
     [0, 0.2, 0.8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
     [0, 0.2, 0, 0, 0, 0.8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
     [0.8, 0.2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
    [[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
     [0, 0, 0.2, 0.8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
     [0, 0, 0.2, 0, 0, 0, 0.8, 0, 0, 0, 0, 0, 0, 0, 0, 0],
     [0, 0.8, 0.2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
    [[0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
     [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
     [0, 0, 0, 0.2, 0, 0, 0, 0.8, 0, 0, 0, 0, 0, 0, 0, 0],
def solve_mdp(horizon, tolerance, discount=0.9):
    """
    Construct the gridworld MDP and solve it using value iteration.
    Also run the resulting policy from a random non-terminal starting
    state, printing each visited state.

    Returns
    -------
    solution: tuple
        The first element is the variation bound; the method has converged
        if it is below the solver's tolerance. The second element is the
        value function. The third element is the Q-value function, from
        which a policy can be derived.
    """
    print time.strftime("%H:%M:%S"), "- Constructing MDP..."

    # Statespace contains the tiger (x, y) and antelope (x, y). Note that
    # this is a very naive state representation: many of these states can be
    # aggregated! We leave this as an exercise to the reader :)
    # S = [(t_x, t_y, a_x, a_y), .. ]
    S = list(itertools.product(range(SQUARE_SIZE), repeat=4))
    # A = tiger actions
    A = ['stand', 'up', 'down', 'left', 'right']

    # T gives the transition probability for every (s, a, s') triple.
    T = []
    for state in range(len(S)):
        coord = decodeState(state)
        T.append([[
            getTransitionProbability(coord, action, decodeState(next_state))
            for next_state in range(len(S))
        ] for action in A])

    # R gives the reward associated with every (s, a, s') triple. In the
    # current example we only specify a reward for s', but we still need to
    # give the entire |S| x |A| x |S| array.
    reward_row = [
        getReward(decodeState(next_state)) for next_state in range(len(S))
    ]
    R = [[reward_row for _ in A] for _ in S]

    # Set up the model.
    model = MDP.SparseModel(len(S), len(A))
    model.setTransitionFunction(T)
    model.setRewardFunction(R)
    model.setDiscount(discount)

    # Perform value iteration.
    print time.strftime("%H:%M:%S"), \
        "- Solving MDP using ValueIteration(horizon={}, tolerance={})".format(
            horizon, tolerance)
    solver = MDP.ValueIteration(horizon, tolerance)
    solution = solver(model)

    print time.strftime("%H:%M:%S"), \
        "- Converged:", solution[0] < solver.getTolerance()

    # Derive a policy from the value function and run it from a random
    # non-terminal starting state so the user can watch it behave.
    _, value_function, q_function = solution
    policy = MDP.Policy(len(S), len(A), value_function)

    s = randint(0, SQUARE_SIZE**4 - 1)
    while model.isTerminal(s):
        s = randint(0, SQUARE_SIZE**4 - 1)

    totalReward = 0
    for t in xrange(100):
        printState(decodeState(s))
        if model.isTerminal(s):
            break

        a = policy.sampleAction(s)
        s1, r = model.sampleSR(s, a)
        totalReward += r
        s = s1

        goup(SQUARE_SIZE)
        # Sleep 1 second so the user can see what is happening.
        time.sleep(1)

    return solution
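# For reference, solve_mdp relies on several helpers defined elsewhere in the
# example (SQUARE_SIZE, decodeState, encodeState, getTransitionProbability,
# getReward, printState, goup), plus the time, itertools and random.randint
# imports. The sketch below shows one possible implementation of the state
# encoding helpers, assuming states are indexed in the same order as the
# itertools.product(range(SQUARE_SIZE), repeat=4) list built above; it is
# illustrative only, not the example's actual code.
#
# def decodeState(state):
#     # Convert a flat index into (tiger_x, tiger_y, antelope_x, antelope_y).
#     coord = []
#     for _ in xrange(4):
#         coord.append(state % SQUARE_SIZE)
#         state //= SQUARE_SIZE
#     return list(reversed(coord))
#
# def encodeState(coord):
#     # Convert (tiger_x, tiger_y, antelope_x, antelope_y) back into a flat index.
#     state = 0
#     for c in coord:
#         state = state * SQUARE_SIZE + c
#     return state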