コード例 #1
0
ファイル: test_lspi.py プロジェクト: stober/lspi
    pi = [gw.linear_policy(w0,s) for s in range(gw.nstates)]
    gw.set_arrows(pi)    
    gw.background()
    gw.mainloop()

if test_realpca:
    # Builds an RBF observer gridworld GUI from stored .npy files and then
    # tries, best-effort, to load a previously pickled trace into ``t``.
    import pdb
    import sys
    #pdb.set_trace()

    # Previously tried endstate sets, kept for reference:
    #endstates = [32, 2016, 1024, 1040, 1056, 1072]
    #endstates = [16,256,264,272,280,496]
    endstates = [272] # [16]
    #endstates = [0] # TODO find proper endstates

    # Alternative environments that were tried:
    # ogw = ObserverGridworldGui("/Users/stober/wrk/lspi/bin/16/5comp.npy", "/Users/stober/wrk/lspi/bin/16/states.npy", endstates = endstates, walls=None)
    # just isn't working for state 16
    # Original note: "nrbf 40 works best" (the call below passes nrbf=80).
    ogw = RBFObserverGridworldGui(
        "/Users/stober/wrk/lspi/bin/16/20comp.npy",
        "/Users/stober/wrk/lspi/bin/16/states.npy",
        endstates=endstates,
        walls=None,
        nrbf=80)
    # ogw.load_features('rbf_obs_features.pck')
    # ogw = ObserverGridworldGui("/Users/stober/wrk/lspi/bin/32/observations4.npy", "/Users/stober/wrk/lspi/bin/32/states.npy", endstates = endstates, walls=None)
    # ogw = GridworldGui(nrows=16,ncols=32,endstates = endstates, walls=[])

    trace_path = workspace + "/traces/complete_trace.pck"
    try:
        #raise Exception # force a trace regeneration
        # Binary mode + context manager: pickle data is bytes, and the handle
        # must be closed even if unpickling fails.
        # NOTE: pickle.load executes arbitrary code from the file; only load
        # trace files you produced yourself.
        with open(trace_path, "rb") as trace_file:
            t = pickle.load(trace_file)
    except Exception as err:
        # Still best-effort (``t`` is left untouched on failure), but the
        # reason is reported and KeyboardInterrupt/SystemExit now propagate.
        sys.stderr.write(
            "warning: could not load trace %s: %s\n" % (trace_path, err))
        # Disabled fallback that regenerates and stores a trace:
        #t = ogw.trace(100000)
        #pickle.dump(t, open(workspace + "/traces/real_pca_trace.pck","w"), pickle.HIGHEST_PROTOCOL)

    #pdb.set_trace()
    #print ogw.phi(0,0)
    #raise Exception
コード例 #2
0
ファイル: test_lspi.py プロジェクト: QueensGambit/lspi
    gw.background()
    gw.mainloop()

if test_realpca:
    # Builds an RBF observer gridworld GUI from stored .npy files and then
    # tries, best-effort, to load a previously pickled trace into ``t``.
    import pdb
    import sys
    #pdb.set_trace()

    # Previously tried endstate sets, kept for reference:
    #endstates = [32, 2016, 1024, 1040, 1056, 1072]
    #endstates = [16,256,264,272,280,496]
    endstates = [272]  # [16]
    #endstates = [0] # TODO find proper endstates

    # Alternative environments that were tried:
    # ogw = ObserverGridworldGui("/Users/stober/wrk/lspi/bin/16/5comp.npy", "/Users/stober/wrk/lspi/bin/16/states.npy", endstates = endstates, walls=None)
    # just isn't working for state 16
    # Original note: "nrbf 40 works best" (the call below passes nrbf=80).
    ogw = RBFObserverGridworldGui("/Users/stober/wrk/lspi/bin/16/20comp.npy",
                                  "/Users/stober/wrk/lspi/bin/16/states.npy",
                                  endstates=endstates,
                                  walls=None,
                                  nrbf=80)
    # ogw.load_features('rbf_obs_features.pck')
    # ogw = ObserverGridworldGui("/Users/stober/wrk/lspi/bin/32/observations4.npy", "/Users/stober/wrk/lspi/bin/32/states.npy", endstates = endstates, walls=None)
    # ogw = GridworldGui(nrows=16,ncols=32,endstates = endstates, walls=[])

    trace_path = workspace + "/traces/complete_trace.pck"
    try:
        #raise Exception # force a trace regeneration
        # Open in binary mode and close the handle deterministically; pickle
        # streams are bytes and text mode breaks them on some platforms.
        # NOTE: pickle.load executes arbitrary code from the file; only load
        # trace files you produced yourself.
        with open(trace_path, "rb") as trace_file:
            t = pickle.load(trace_file)
    except Exception as err:
        # Remain best-effort (``t`` is not rebound on failure) but say why,
        # and let KeyboardInterrupt/SystemExit through instead of eating them.
        sys.stderr.write(
            "warning: could not load trace %s: %s\n" % (trace_path, err))
        # Disabled fallback that regenerates and stores a trace:
        #t = ogw.trace(100000)
        #pickle.dump(t, open(workspace + "/traces/real_pca_trace.pck","w"), pickle.HIGHEST_PROTOCOL)

    #pdb.set_trace()
    #print ogw.phi(0,0)
コード例 #3
0
ファイル: 2d_rotation.py プロジェクト: DaomingLyu/lspi
if __name__ == '__main__':

    workspace = "{0}/wrk/lspi/bin".format(os.environ['HOME'])

    if False:
        # fix isomap issue on bad matches
        ematrix = pickle.load(open('ematrix0.pck'))
        y, s, adj = isomap(ematrix)

    if True:

        endstates = [272]  # [16]

        ogw = RBFObserverGridworldGui(
            "/Users/stober/wrk/lspi/bin/16/20comp.npy",
            "/Users/stober/wrk/lspi/bin/16/states.npy",
            endstates=endstates,
            walls=None,
            nrbf=80)

        # ogw = RBFObserverGridworld("/Users/stober/wrk/lspi/bin/16/20comp.npy", "/Users/stober/wrk/lspi/bin/16/states.npy", endstates = endstates, walls=None, nrbf=80)

        t = pickle.load(open(workspace + "/traces/complete_trace.pck"))

        old_endstates = [16, 256, 264, 272, 280, 496]
        t = modify_endstates(t, old_endstates, endstates, action_costs=True)

        policy0 = np.zeros(ogw.nfeatures())
        w0, weights0 = LSPI(t,
                            0.005,
                            ogw,
                            policy0,