Exemple #1
0
def LSPIRmax(D, epsilon, env, policy0, maxiter=10, resample_size=1000, show=False, resample_epsilon=0.1, rmax=1.0):
    current = policy0
    all_policies = [current]

    iters = 0
    finished = False
    track = TrackKnown(env.nstates, env.nactions, 1)

    # TODO: need to couple track object with sample set more tightly.
    print "Pre uniq: ", len(D)
    D = track.uniq(D)
    print "Post uniq: ", len(D)

    track.init(D) # initialize knowledge

    print "Resample epsilon: ", resample_epsilon

    if show:
        diagnostics = Diagnostics(env)

    while iters < maxiter and not finished:

        all_policies.append(current)

        # A,b,current,info = LSTDQRmax(D, env, current, track, rmax=rmax)
        start_time = time.time()
        A, b, current, info = ParallelLSTDQRmax(D, env, current, track, rmax=rmax)
        end_time = time.time()
        print "Loop time: ", end_time - start_time
        policy = partial(env.epsilon_linear_policy, resample_epsilon, current) # need to detect/escape cycles?
        
        # more trace data
        t = env.trace(1000, policy=policy, reset_on_cycle=False, reset_on_endstate=False, stop_on_cycle=True)
        print "Trace length: ", len(t)
        track.resample(D, t, take_all=False) # adds new samples
        track.diagnostics()

        if show:
            print diagnostics(iters, current, A)
            # for (i,p) in enumerate(all_policies):
            #     print "policy: ", i, la.norm(p - current)

        iters += 1

        print "Iterations: ", iters

        for p in all_policies:
            if la.norm(p - current) < epsilon and track.all_known():  
                finished = True

    return current, all_policies
Exemple #2
0
def LSPIRmax(D, epsilon, env, policy0, method = "dense", maxiter = 10, resample_size = 1000, show = False, resample_epsilon = 0.1, rmax = 1.0):
    current = policy0
    all_policies = [current]

    iters = 0
    finished = False
    track = TrackKnown(D, env.nstates, env.nactions, 1)

    if show:
        diagnostics = Diagnostics(env)

    while iters < maxiter and not finished:

        print "Iterations: ", iters
        all_policies.append(current)

        start_time = time.time()
        if method is "dense":
            A,b,current,info = LSTDQRmax(track, env, current, rmax=rmax)
        elif method is "parallel":
            A,b,current,info = ParallelLSTDQRmax(track, env, current, rmax=rmax)
        else:
            raise ValueError, "Unknown LSTDQ method!"
        end_time = time.time()
        print "Loop time: ", end_time - start_time

        policy = partial(env.epsilon_linear_policy, resample_epsilon, current) # need to detect/escape cycles?
        
        # more trace data
        t = env.trace(1000, policy = policy, reset_on_cycle = False, reset_on_endstate = False, stop_on_cycle=True)
        print "Trace length: ", len(t)
        track.extend(t, take_all=False) # adds new samples
        track.diagnostics()

        if show:
            print diagnostics(iters,current,A)

        iters += 1

        for p in all_policies:
            if la.norm(p - current) < epsilon and track.all_known():  
                finished = True
                print "Finished"

    return current, all_policies
Exemple #3
0
            dict_loop,
            (D[i:j], env, w, 0.0))  # note that damping needs to be zero here
        results.append(r)

    k = len(w)
    A = sp.identity(k, format='csr') * damping
    b = sp_create(k, 1, 'csr')
    for r in results:
        T, t = r.get()
        A = A + T
        b = b + t

    # close out the pool of workers
    pool.close()
    pool.join()

    w, info = solve(A, b, method="spsolve")
    return A, b, w, info


if __name__ == '__main__':
    # Ad-hoc comparison of the dense and parallel solvers on a small
    # gridworld, driven by a previously pickled trace.

    from gridworld.gridworld8 import SparseGridworld8
    import cPickle as pickle

    gw = SparseGridworld8(nrows=5, ncols=5, endstates=[0], walls=[])
    # Pickles are byte streams: open in binary mode, and use a context
    # manager so the handle is closed even if unpickling fails.
    with open("/Users/stober/wrk/lspi/bin/rmax_trace.pck", "rb") as trace_file:
        t = pickle.load(trace_file)
    policy0 = np.zeros(gw.nfeatures())
    track = TrackKnown(t, gw.nstates, gw.nactions, 1)
    compare(LSTDQRmax, ParallelLSTDQRmax, track, gw, policy0)