Example #1
0
# Switches gating the experiment sections of this script; set one to True to
# enable its section (only the `test_rmax` section is visible here).
test_rmax, test_realpca, test_alias, test_complete = False, False, False, False


def compare_all_phi(gw):
    """Print (and return) the first duplicate among all state-action features.

    Evaluates ``gw.phi(s, a)`` for every state in ``gw.states`` and every
    action in ``gw.actions`` and hands the collected list to
    ``find_duplicates`` (defined elsewhere; presumably it yields the entries
    that occur more than once -- confirm against its definition).

    Returns:
        The first item yielded by ``find_duplicates``, or ``None`` when it
        yields nothing. The same value is printed.
    """
    data = [gw.phi(s, a) for s in gw.states for a in gw.actions]
    # A default for next() means "no duplicates" prints None instead of
    # leaking StopIteration out of this function (the old code raised there).
    first_dup = next(iter(find_duplicates(data)), None)
    print(first_dup)
    return first_dup


if test_rmax:
    gw = GridworldGui(nrows=5, ncols=5, endstates=[0], walls=[])

    # try:
    #     raise ValueError # for new trace
    #     t = pickle.load(open("rmax_trace.pck"))
    # except:
    #     t = gw.trace(100, show = True)
    #     pickle.dump(t, open("rmax_trace.pck","w"), pickle.HIGHEST_PROTOCOL)

    policy0 = np.zeros(gw.nfeatures())
    t = []
    # TODO - The tolerances for lsqr need to be related to the tolerances for the policy. Otherwise the number of iterations will be far larger than needed.
    w0, weights0 = LSPIRmax(t,
                            0.003,
                            gw,
                            policy0,
Example #2
0
#! /usr/bin/env python
"""
Author: Jeremy M. Stober
Program: TEST_GUI.PY
Date: Wednesday, June  6 2012
Description: Code to test gui changes.
"""

from gridworld.gridworldgui import GridworldGui
from lspi import LSPI
import numpy as np

endstates = [32, 2016, 1024, 1040, 1056, 1072]
gw = GridworldGui(nrows=32, ncols=64, endstates=endstates, walls=[])
# Presumably samples a 10000-step trace from the gridworld -- confirm.
t = gw.trace(10000)
# All-zero initial weight vector, one entry per feature.
z = np.zeros(gw.nfeatures())
# `w` was printed below but its assignment had been commented out, so the
# script ended in a NameError. 0.0001 is presumably LSPI's convergence
# tolerance -- NOTE(review): confirm against LSPI's signature.
w = LSPI(t, 0.0001, gw, z)
print(gw.phi(0, 0))
print(gw.phi(0, 1))
print(w)