Example #1
0
        for a in gw.actions:
            data.append(gw.phi(s, a))
    print(next(find_duplicates(data)))  # stop iteration?


if test_rmax:
    # Open 5x5 grid (walls=[]) whose only end state is state 0.
    gw = GridworldGui(nrows=5, ncols=5, endstates=[0], walls=[])

    # NOTE: an earlier revision loaded/saved a cached 100-step trace in
    # "rmax_trace.pck" via pickle; this run hands LSPIRmax an empty trace.
    # TODO - The tolerances for lsqr need to be related to the tolerances
    # for the policy. Otherwise the number of iterations will be far larger
    # than needed.
    t = []
    policy0 = np.zeros(gw.nfeatures())
    w0, weights0 = LSPIRmax(t, 0.003, gw, policy0,
                            maxiter=1000, show=True,
                            resample_epsilon=0.0, rmax=1000)

    # One action per state, as chosen by gw.linear_policy under the learned
    # weights w0, then rendered on the GUI before entering its main loop.
    pi = [gw.linear_policy(w0, s) for s in range(gw.nstates)]
    gw.set_arrows(pi)
    gw.background()
    gw.mainloop()
Example #2
0
#! /usr/bin/env python
"""
Author: Jeremy M. Stober
Program: TEST_GUI.PY
Date: Wednesday, June  6 2012
Description: Code to test gui changes.
"""

from gridworld.gridworldgui import GridworldGui
from lspi import LSPI
import numpy as np

endstates = [32, 2016, 1024, 1040, 1056, 1072]
gw = GridworldGui(nrows=32, ncols=64, endstates=endstates, walls=[])
t = gw.trace(10000)
z = np.zeros(gw.nfeatures())
# ``w`` is printed below, so LSPI must actually run; leaving this call
# commented out made ``print(w)`` fail with NameError. ``z`` (all-zero
# weights) is the starting policy. The 0.0001 argument is presumably the
# convergence tolerance -- confirm against lspi.LSPI.
w = LSPI(t, 0.0001, gw, z)
print(gw.phi(0, 0))
print(gw.phi(0, 1))
print(w)

Example #3
0
File: test_lspi.py  Project: stober/lspi
    for s in gw.states:
        for a in gw.actions:
            data.append(gw.phi(s,a))
    print find_duplicates(data).next() # stop iteration?

if test_rmax:
    # Open 5x5 grid (walls=[]) whose only end state is state 0.
    gw = GridworldGui(nrows=5, ncols=5, endstates=[0], walls=[])

    # NOTE: an earlier revision loaded/saved a cached 100-step trace in
    # "rmax_trace.pck" via pickle; this run hands LSPIRmax an empty trace.
    # TODO - The tolerances for lsqr need to be related to the tolerances
    # for the policy. Otherwise the number of iterations will be far larger
    # than needed.
    t = []
    policy0 = np.zeros(gw.nfeatures())
    w0, weights0 = LSPIRmax(t, 0.003, gw, policy0,
                            maxiter=1000, show=True,
                            resample_epsilon=0.0, rmax=1000)

    # One action per state, as chosen by gw.linear_policy under the learned
    # weights w0, then rendered on the GUI before entering its main loop.
    pi = [gw.linear_policy(w0, s) for s in range(gw.nstates)]
    gw.set_arrows(pi)
    gw.background()
    gw.mainloop()

if test_walls:
    # 5x5 grid with the inner 3x3 block of cells (indices 1..3 on both axes)
    # walled off, and state 0 as the only end state.
    gw = GridworldGui(nrows=5, ncols=5, endstates=[0],
                      walls=[(1, 1), (1, 2), (1, 3),
                             (2, 1), (2, 2), (2, 3),
                             (3, 1), (3, 2), (3, 3)])
    # Reuse a cached trace when one can be loaded; otherwise generate a fresh
    # 1000-step trace and cache it. Pickle data is binary, so both files are
    # opened in binary mode (text mode fails on Python 3 and corrupts the
    # data on Windows), and `with` guarantees the handles are closed.
    # NOTE: pickle.load can execute arbitrary code -- only load a cache file
    # this script wrote itself.
    try:
        with open("walls_trace.pck", "rb") as f:
            t = pickle.load(f)
    except Exception:
        # Missing, truncated, or stale cache: fall back to regenerating it.
        # Unlike the former bare `except:`, Ctrl-C and SystemExit still
        # propagate.
        t = gw.trace(1000, show=False)
        with open("walls_trace.pck", "wb") as f:
            pickle.dump(t, f, pickle.HIGHEST_PROTOCOL)
Example #4
0
#! /usr/bin/env python
"""
Author: Jeremy M. Stober
Program: TEST_GUI.PY
Date: Wednesday, June  6 2012
Description: Code to test gui changes.
"""

from gridworld.gridworldgui import GridworldGui
from lspi import LSPI
import numpy as np

endstates = [32, 2016, 1024, 1040, 1056, 1072]
gw = GridworldGui(nrows=32, ncols=64, endstates=endstates, walls=[])
t = gw.trace(10000)
z = np.zeros(gw.nfeatures())
# ``w`` is printed below, so LSPI must actually run; leaving this call
# commented out made ``print w`` fail with NameError. ``z`` (all-zero
# weights) is the starting policy. The 0.0001 argument is presumably the
# convergence tolerance -- confirm against lspi.LSPI.
w = LSPI(t, 0.0001, gw, z)
# Single-argument print(...) behaves the same on Python 2 and Python 3.
print(gw.phi(0, 0))
print(gw.phi(0, 1))
print(w)