def gridworld(world='Easy'):
    """Select a hard-coded grid layout, print it, and return it in matrix form.

    Args:
        world: ``'Easy'`` selects the 5x5 layout; any other value selects the
            16x16 layout. The value is also interpolated into the printed
            banner.

    Returns:
        A 3-tuple ``(matrix, maxX, maxY)``. ``matrix`` is whatever
        ``MapPrinter().mapToMatrix`` produces from the chosen layout.
        ``maxX`` and ``maxY`` are both ``len(layout) - 1``. That is only
        correct because both layouts are square.

    Side effects:
        Prints a banner line, then the map via ``MapPrinter().printMap``.
    """
    # NOTE(review): cells hold 0/1; presumably 1 marks a blocked cell --
    # confirm against MapPrinter / the grid-world domain generator.
    easy_layout = [
        [0, 1, 0, 0, 0],
        [0, 1, 0, 1, 0],
        [0, 1, 0, 0, 0],
        [0, 1, 1, 1, 0],
        [0, 0, 0, 0, 0],
    ]
    hard_layout = [
        [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
        [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
        [1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
        [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
        [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
        [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
        [0, 0, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    ]
    layout = easy_layout if world == 'Easy' else hard_layout

    # Square layouts: the row count bounds both axes.
    side = len(layout)
    matrix = MapPrinter().mapToMatrix(deepcopy(layout))
    max_x = max_y = side - 1

    # Print the map that is being analyzed.
    banner = "\n\n*** {} Grid World Analysis ***\n".format(world)
    print(banner)
    MapPrinter().printMap(MapPrinter.matrixToMap(matrix))

    return matrix, max_x, max_y
for i in range(n): for j in range(n): tmp[i][j] = userMap[i][j] userMap = MapPrinter().mapToMatrix(tmp) maxX = maxY = n - 1 gen = BasicGridWorld(userMap, maxX, maxY) domain = gen.generateDomain() initialState = gen.getExampleState(domain) rf = BasicRewardFunction(maxX, maxY, userMap) tf = BasicTerminalFunction(maxX, maxY) env = SimulatedEnvironment(domain, rf, tf, initialState) # Print the map that is being analyzed print "/////{} Grid World Analysis/////\n".format(world) MapPrinter().printMap(MapPrinter.matrixToMap(userMap)) visualizeInitialGridWorld(domain, gen, env) hashingFactory = SimpleHashableStateFactory() increment = MAX_ITERATIONS / NUM_INTERVALS timing = defaultdict(list) rewards = defaultdict(list) steps = defaultdict(list) convergence = defaultdict(list) allStates = getAllStates(domain, rf, tf, initialState) # Value Iteration iterations = range(1, MAX_ITERATIONS + 1) vi = ValueIteration(domain, rf, tf, discount, hashingFactory, -1, 1) vi.setDebugCode(0) vi.performReachabilityFrom(initialState) vi.toggleUseCachedTransitionDynamics(False)