test_rmax = False test_realpca = False #True test_alias = False test_complete = False def compare_all_phi(gw): data = [] for s in gw.states: for a in gw.actions: data.append(gw.phi(s, a)) print(next(find_duplicates(data))) # stop iteration? if test_rmax: gw = GridworldGui(nrows=5, ncols=5, endstates=[0], walls=[]) # try: # raise ValueError # for new trace # t = pickle.load(open("rmax_trace.pck")) # except: # t = gw.trace(100, show = True) # pickle.dump(t, open("rmax_trace.pck","w"), pickle.HIGHEST_PROTOCOL) policy0 = np.zeros(gw.nfeatures()) t = [] # TODO - The tolerances for lsqr need to be related to the tolerances for the policy. Otherwise the number of iterations will be far larger than needed. w0, weights0 = LSPIRmax(t, 0.003, gw, policy0,
#! /usr/bin/env python
"""
Author: Jeremy M. Stober
Program: TEST_GUI.PY
Date: Wednesday, June 6 2012
Description: Code to test gui changes.

Builds a 32x64 GridworldGui, collects a trace with gw.trace(10000), and
prints gw.phi(0, 0) and gw.phi(0, 1). When ``run_lspi`` is True it also
runs LSPI on the trace and prints the resulting ``w``.
"""

from gridworld.gridworldgui import GridworldGui
from lspi import LSPI
import numpy as np

# Set to True to run LSPI on the collected trace and print its result.
# Off by default, matching the previously commented-out LSPI call.
run_lspi = False

endstates = [32, 2016, 1024, 1040, 1056, 1072]
gw = GridworldGui(nrows=32, ncols=64, endstates=endstates, walls=[])
t = gw.trace(10000)
# Zero vector of length gw.nfeatures(), passed as LSPI's fourth argument
# (presumably the initial policy/weights -- confirm against lspi.LSPI).
z = np.zeros(gw.nfeatures())

print(gw.phi(0, 0))
print(gw.phi(0, 1))

# Only print `w` when it has actually been computed; printing it
# unconditionally while the LSPI call is disabled raises NameError.
if run_lspi:
    w = LSPI(t, 0.0001, gw, z)
    print(w)