示例#1
0
# Boolean experiment toggles, all switched off here. `test_rmax` gates the
# `if test_rmax:` R-max LSPI section further down.
test_rmax = test_realpca = test_alias = test_complete = False


def compare_all_phi(gw):
    """Print (and return) the first duplicated feature vector of a gridworld.

    Collects ``gw.phi(s, a)`` for every state in ``gw.states`` and every
    action in ``gw.actions`` (states outer, actions inner) and reports the
    first item produced by ``find_duplicates`` on that list.

    Args:
        gw: object exposing ``states``, ``actions`` and ``phi(s, a)``.

    Returns:
        The first item yielded by ``find_duplicates``, or ``None`` when it
        yields nothing (a message is printed in that case instead of letting
        StopIteration escape).
    """
    data = [gw.phi(s, a) for s in gw.states for a in gw.actions]

    # find_duplicates is defined elsewhere; it is assumed to return an
    # iterable of duplicate entries -- TODO confirm. iter() accepts both an
    # iterator and a plain sequence. The sentinel distinguishes "nothing
    # found" from a legitimately falsy item.
    missing = object()
    first = next(iter(find_duplicates(data)), missing)
    if first is missing:
        print("no duplicates found among phi(s, a) vectors")
        return None
    print(first)
    return first


# R-max LSPI experiment on a 5x5 gridworld whose only end state is 0.
# NOTE(review): this excerpt is truncated -- the LSPIRmax(...) call at the
# bottom stops after `policy0,` and its remaining arguments are not shown.
if test_rmax:
    gw = GridworldGui(nrows=5, ncols=5, endstates=[0], walls=[])

    # Disabled trace caching: the `raise ValueError` forces a fresh trace, which
    # is then pickled to rmax_trace.pck. NOTE(review): if re-enabled under
    # Python 3, pickle needs binary file modes ("rb"/"wb"), not the defaults/"w".
    # try:
    #     raise ValueError # for new trace
    #     t = pickle.load(open("rmax_trace.pck"))
    # except:
    #     t = gw.trace(100, show = True)
    #     pickle.dump(t, open("rmax_trace.pck","w"), pickle.HIGHEST_PROTOCOL)

    # All-zero starting policy weights, one per feature.
    policy0 = np.zeros(gw.nfeatures())
    # Empty trace handed to LSPIRmax in place of the cached one above.
    t = []
    # TODO - The tolerances for lsqr need to be related to the tolerances for the policy. Otherwise the number of iterations will be far larger than needed.
    w0, weights0 = LSPIRmax(t,
                            0.003,
                            gw,
                            policy0,
示例#2
0
文件: test_lspi.py 项目: stober/lspi
# Boolean experiment toggles, all switched off here. `test_rmax` gates the
# `if test_rmax:` R-max LSPI section further down.
test_walls = test_fakepca = test_rmax = False
test_realpca = test_alias = test_complete = False

def compare_all_phi(gw):
    """Print (and return) the first duplicated feature vector of a gridworld.

    Collects ``gw.phi(s, a)`` for every state in ``gw.states`` and every
    action in ``gw.actions`` (states outer, actions inner) and reports the
    first item produced by ``find_duplicates`` on that list.

    Written so it runs unchanged on Python 2 and Python 3 (single-argument
    ``print(...)`` and the ``next()`` builtin instead of ``.next()``).

    Args:
        gw: object exposing ``states``, ``actions`` and ``phi(s, a)``.

    Returns:
        The first item yielded by ``find_duplicates``, or ``None`` when it
        yields nothing (a message is printed in that case instead of letting
        StopIteration escape).
    """
    data = [gw.phi(s, a) for s in gw.states for a in gw.actions]

    # find_duplicates is defined elsewhere; it is assumed to return an
    # iterable of duplicate entries -- TODO confirm. The sentinel separates
    # "nothing found" from a legitimately falsy item.
    missing = object()
    first = next(iter(find_duplicates(data)), missing)
    if first is missing:
        print("no duplicates found among phi(s, a) vectors")
        return None
    print(first)
    return first

if test_rmax:
    # R-max LSPI on a 5x5 gridworld whose only end state is 0.
    gw = GridworldGui(nrows=5, ncols=5, endstates=[0], walls=[])

    # Start from an empty trace and all-zero policy weights (one per feature).
    t = []
    policy0 = np.zeros(gw.nfeatures())

    # TODO: the lsqr tolerances should be tied to the policy tolerance;
    # otherwise far more iterations run than are actually needed.
    w0, weights0 = LSPIRmax(
        t,
        0.003,
        gw,
        policy0,
        maxiter=1000,
        show=True,
        resample_epsilon=0.0,
        rmax=1000
    )

    # Greedy action per state under the learned weights, drawn as arrows.
    pi = [gw.linear_policy(w0, s) for s in range(gw.nstates)]
    gw.set_arrows(pi)
    gw.background()
示例#3
0
#! /usr/bin/env python
"""
Author: Jeremy M. Stober
Program: TEST_GUI.PY
Date: Wednesday, June  6 2012
Description: Code to test gui changes.
"""

from gridworld.gridworldgui import GridworldGui
from lspi import LSPI
import numpy as np

endstates = [32, 2016, 1024, 1040, 1056, 1072]
gw = GridworldGui(nrows=32, ncols=64, endstates=endstates, walls=[])
t = gw.trace(10000)  # consumed only by the optional LSPI run below
z = np.zeros(gw.nfeatures())  # all-zero starting weights passed to LSPI

# Toggle the LSPI solve. `w` exists only when LSPI runs, so printing it is
# guarded by the same flag -- an unconditional print(w) raises NameError
# while the solver is disabled.
run_lspi = False
if run_lspi:
    w = LSPI(t, 0.0001, gw, z)

print(gw.phi(0, 0))
print(gw.phi(0, 1))
if run_lspi:
    print(w)

示例#4
0
Date: Wednesday, July  4 2012
Description: Try different solution methods for gridworld problems.
"""

from gridworld.gridworldgui import GridworldGui
import pdb
from lspi import LSPI
import numpy as np
from td import TDQ,TD,Sarsa,SampleModelValueIteration
import time

# Active setup: a 32x64 gridworld with six end states.
# Commented-out smaller alternative: nrows=9, ncols=9, endstates=[40].
endstates = [32, 2016, 1024, 1040, 1056, 1072]
gw = GridworldGui(nrows=32, ncols=64, endstates=endstates, walls=[])

# Disabled GUI tweaks: gw.updategui = False; gw.draw_state_labels()

# Commented-out alternative learners:
#   TDQ(8, 81, 0.1, 0.9, 0.9)
#   TD(81, 0.1, 0.9, 0.9)
#   Sarsa(8, 81, 0.3, 0.9, 0.9, 0.4)
# NOTE(review): 81 presumably matches the 9x9 alternative (9 * 9 = 81) rather
# than the 32x64 grid in use; confirm before calling learner.learn().
learner = SampleModelValueIteration(8, 81)

# The learner is not trained here (disabled: v, pi = learner.learn(100, gw,
# verbose=True)); values and policy come from gw.value_iteration() instead.
v, pi = gw.value_iteration()
# Disabled alternative policy: pi = np.ones(gw.nstates, dtype='int')
示例#5
0
Date: Wednesday, July  4 2012
Description: Try different solution methods for gridworld problems.
"""

from gridworld.gridworldgui import GridworldGui
import pdb
from lspi import LSPI
import numpy as np
from td import TDQ, TD, Sarsa, SampleModelValueIteration
import time

# Active setup: a 32x64 gridworld with six end states.
# Commented-out smaller alternative: nrows=9, ncols=9, endstates=[40].
endstates = [32, 2016, 1024, 1040, 1056, 1072]
gw = GridworldGui(
    nrows=32,
    ncols=64,
    endstates=endstates,
    walls=[]
)

# Disabled GUI tweaks: gw.updategui = False; gw.draw_state_labels()

# Commented-out alternative learners:
#   TDQ(8, 81, 0.1, 0.9, 0.9)
#   TD(81, 0.1, 0.9, 0.9)
#   Sarsa(8, 81, 0.3, 0.9, 0.9, 0.4)
# NOTE(review): 81 presumably matches the 9x9 alternative (9 * 9 = 81) rather
# than the 32x64 grid in use; confirm before calling learner.learn().
learner = SampleModelValueIteration(8, 81)

# The learner is not trained here (disabled: v, pi = learner.learn(100, gw,
# verbose=True)); values and policy come from gw.value_iteration() instead.
v, pi = gw.value_iteration()
# Disabled alternative policy: pi = np.ones(gw.nstates, dtype='int')
示例#6
0
#! /usr/bin/env python
"""
Author: Jeremy M. Stober
Program: TEST_GUI.PY
Date: Wednesday, June  6 2012
Description: Code to test gui changes.
"""

from gridworld.gridworldgui import GridworldGui
from lspi import LSPI
import numpy as np

endstates = [32, 2016, 1024, 1040, 1056, 1072]
gw = GridworldGui(nrows=32, ncols=64, endstates=endstates, walls=[])
t = gw.trace(10000)  # consumed only by the optional LSPI run below
z = np.zeros(gw.nfeatures())  # all-zero starting weights passed to LSPI

# Toggle the LSPI solve. `w` exists only when LSPI runs, so printing it is
# guarded by the same flag -- an unconditional `print w` raises NameError
# while the solver is disabled. Single-argument print(...) behaves the same
# on Python 2 and Python 3.
run_lspi = False
if run_lspi:
    w = LSPI(t, 0.0001, gw, z)

print(gw.phi(0, 0))
print(gw.phi(0, 1))
if run_lspi:
    print(w)