예제 #1
0
파일: LSPI_SARSA.py 프로젝트: irenge/RLPy
class LSPI_SARSA(Agent):
    """Agent that drives an LSPI learner and a SARSA learner side by side.

    Every transition is handed to the LSPI instance through ``process``.
    On steps where ``(samples_count + 1)`` is a multiple of the LSPI
    instance's ``steps_between_LSPI``, ``representationExpansionLSPI`` is
    run; on every other step the transition is passed to the SARSA
    instance's ``learn``.
    """

    def __init__(self, representation, policy, domain, logger,
                 lspi_iterations=5, sample_window=100, epsilon=1e-3,
                 re_iterations=100, initial_alpha=.1, lambda_=0,
                 alpha_decay_mode='dabney', boyan_N0=1000):
        """Build the two wrapped learners and initialise the base agent.

        ``representation``, ``policy``, ``domain`` and ``logger`` are shared
        by both learners and the base class. ``initial_alpha``, ``lambda_``,
        ``alpha_decay_mode`` and ``boyan_N0`` are forwarded to ``SARSA``;
        ``lspi_iterations``, ``sample_window``, ``epsilon`` and
        ``re_iterations`` are forwarded to ``LSPI``.
        """
        self.SARSA = SARSA(representation, policy, domain, logger,
                           initial_alpha, lambda_, alpha_decay_mode, boyan_N0)
        self.LSPI = LSPI(representation, policy, domain, logger,
                         lspi_iterations, sample_window, epsilon,
                         re_iterations)
        super(LSPI_SARSA, self).__init__(representation, policy, domain,
                                         logger)

    def learn(self, s, a, r, ns, na, terminal):
        """Process one transition ``(s, a, r, ns, na, terminal)``.

        The sample is always recorded by LSPI. Depending on the sample
        counter, either the LSPI representation-expansion step or a SARSA
        update is performed for this transition.
        """
        self.LSPI.process(s, a, r, ns, na, terminal)
        # Parentheses are required: `%` binds tighter than `+`, so the
        # unparenthesised form evaluated `samples_count + (1 % steps)` and
        # the LSPI branch could never trigger for a non-negative counter.
        if (self.LSPI.samples_count + 1) % self.LSPI.steps_between_LSPI == 0:
            self.LSPI.representationExpansionLSPI()
            if terminal:
                self.episodeTerminated()
        else:
            self.SARSA.learn(s, a, r, ns, na, terminal)
예제 #2
0
파일: LSPI_SARSA.py 프로젝트: irenge/RLPy
 def __init__(
     self,
     representation,
     policy,
     domain,
     logger,
     lspi_iterations=5,
     sample_window=100,
     epsilon=1e-3,
     re_iterations=100,
     initial_alpha=.1,
     lambda_=0,
     alpha_decay_mode='dabney',
     boyan_N0=1000,
 ):
     """Create the wrapped SARSA and LSPI learners, then init the base class.

     The first four arguments are passed unchanged to ``SARSA``, ``LSPI``
     and the parent constructor. The SARSA-specific settings
     (``initial_alpha``, ``lambda_``, ``alpha_decay_mode``, ``boyan_N0``)
     and the LSPI-specific settings (``lspi_iterations``,
     ``sample_window``, ``epsilon``, ``re_iterations``) are appended
     positionally to the respective constructor calls.
     """
     # Arguments common to both learners and the parent constructor.
     common = (representation, policy, domain, logger)
     sarsa_args = common + (initial_alpha, lambda_, alpha_decay_mode, boyan_N0)
     lspi_args = common + (lspi_iterations, sample_window, epsilon, re_iterations)

     # Same construction order as before: SARSA, then LSPI, then the parent.
     self.SARSA = SARSA(*sarsa_args)
     self.LSPI = LSPI(*lspi_args)
     super(LSPI_SARSA, self).__init__(*common)
예제 #3
0
# import sampler
# 
# samples = sampler.sample(50)
# print(samples)

import LSPI
import matplotlib.pyplot as plt
import sampler

# Run LSPI; the call returns a pair that is unpacked into `pi` (later used
# as a policy) and `distances`.
pi,distances = LSPI.run_LSPI()

# Plot the returned distances and display the figure.
# NOTE(review): presumably a per-iteration convergence measure — confirm
# against LSPI.run_LSPI.
plt.plot(distances)
plt.show()

# Hand the resulting policy to the sampler module.
sampler.use_policy(pi)