Example #1
def main():
    # env = gym.make("FrozenLake-v0", is_slippery=False)  # 0 left, 1 down, 2 right, 3 up
    # env = FrozenLakeWapper(env)

    env = gym.make("CliffWalking-v0")  # 0 up, 1 right, 2 down, 3 left
    env = CliffWalkingWapper(env)

    agent = SarsaAgent(obs_n=env.observation_space.n,
                       act_n=env.action_space.n,
                       learning_rate=0.1,
                       gamma=0.9,
                       e_greed=0.1)

    is_render = False
    for episode in range(500):
        ep_reward, ep_steps = run_episode(env, agent, is_render)
        print('Episode %s: steps = %s , reward = %.1f' %
              (episode, ep_steps, ep_reward))

        # Render every 20 episodes to check progress
        if episode % 20 == 0:
            is_render = True
        else:
            is_render = False
    # Training finished; evaluate the learned policy
    test_episode(env, agent)
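
Examples #1, #2, and #4 call run_episode and test_episode helpers that are not shown in these snippets. Below is a minimal sketch of what they could look like, assuming a PARL-tutorial-style agent with sample (epsilon-greedy), predict (greedy), and learn (SARSA update) methods, and a run_episode that returns (reward, steps); these names and the return order are assumptions, not the original code.

def run_episode(env, agent, render=False):
    total_steps, total_reward = 0, 0
    obs = env.reset()
    action = agent.sample(obs)  # epsilon-greedy action for the first step
    while True:
        next_obs, reward, done, _ = env.step(action)
        next_action = agent.sample(next_obs)  # on-policy: choose the next action first
        agent.learn(obs, action, reward, next_obs, next_action, done)  # SARSA update
        obs, action = next_obs, next_action
        total_reward += reward
        total_steps += 1
        if render:
            env.render()
        if done:
            break
    return total_reward, total_steps

def test_episode(env, agent):
    total_reward = 0
    obs = env.reset()
    while True:
        action = agent.predict(obs)  # greedy action, no exploration
        obs, reward, done, _ = env.step(action)
        total_reward += reward
        env.render()
        if done:
            break
    print('test reward = %.1f' % total_reward)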
Example #2
def main():
    # Initialize the environment
    # FrozenLake environment
    # env = gym.make("FrozenLake-v0", is_slippery=False)  # 0 left, 1 down, 2 right, 3 up
    # env = FrozenLakeWapper(env)

    # CliffWalking environment
    env = gym.make("CliffWalking-v0")
    env = CliffWalkingWapper(env)

    # Initialize the agent
    agent = SarsaAgent(obs_n=env.observation_space.n,
                       act_n=env.action_space.n,
                       learning_rate=0.1,
                       gamma=0.9,
                       e_greed=0.1)

    # Start training
    render = False
    for episode in range(500):
        ep_steps, ep_reward = run_episode(env, agent, render)
        print('Episode %s: steps = %s , reward = %.1f' %
              (episode, ep_steps, ep_reward))
        # Render every 20 episodes to check progress
        if episode % 20 == 0:
            render = True
        else:
            render = False

    # Training finished; evaluate the learned policy
    test_episode(env, agent)
Example #3
def choose_action(self, env, state):
    # If no option is running (or the current one just terminated), pick a new
    # option with the base SARSA policy; re-sample until one can initiate here.
    if self.option is None or self.option.is_terminated(env, state):
        self.option = None
        while self.option is None:
            opti = SarsaAgent.choose_action(self, env, state)[0]
            self.option = self.options[opti]
            if not self.option.can_initiate(env, state):
                # Mark options that cannot start in this state so they are not re-selected
                self.qTable[state[:self.stateDim] + (opti,)] = float('-inf')
                self.option = None
            # print zip(*np.where(np.isinf(self.qTable)))
            assert not np.all(np.isinf(self.qTable[state[:self.stateDim]])), (self, state)
            # else:
            #     pass
            #     # print 'started option', self.option
    return self.option.choose_action(env, state)
Example #4
def main():
    env = gym.make("FrozenLake-v0",
                   is_slippery=False)  # 0 left, 1 down, 2 right, 3 up
    env = FrozenLakeWapper(env)

    agent = SarsaAgent(obs_n=env.observation_space.n,
                       act_n=env.action_space.n,
                       learning_rate=0.1,
                       gamma=0.9,
                       e_greed=0.1)

    for episode in range(500):
        ep_reward, ep_steps = run_episode(env, agent)
        print('Episode %s: steps = %s , reward = %.1f' %
              (episode, ep_steps, ep_reward))

    # Training finished; evaluate the learned policy
    test_episode(env, agent)
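
The SarsaAgent constructed in Examples #1, #2, and #4 is not defined in these snippets. As a rough sketch only, matching the constructor arguments used above and the sample/predict/learn interface assumed in the run_episode sketch after Example #1 (the method names and internals are assumptions, not the original implementation), a tabular version might look like:

import numpy as np

class SarsaAgent:
    def __init__(self, obs_n, act_n, learning_rate=0.01, gamma=0.9, e_greed=0.1):
        self.act_n = act_n
        self.lr = learning_rate
        self.gamma = gamma
        self.epsilon = e_greed
        self.Q = np.zeros((obs_n, act_n))  # tabular action-value estimates

    def sample(self, obs):
        # Epsilon-greedy behaviour policy
        if np.random.uniform(0, 1) < self.epsilon:
            return np.random.choice(self.act_n)
        return self.predict(obs)

    def predict(self, obs):
        # Greedy action, breaking ties at random
        q_list = self.Q[obs, :]
        best = np.where(q_list == q_list.max())[0]
        return np.random.choice(best)

    def learn(self, obs, action, reward, next_obs, next_action, done):
        # On-policy SARSA update: bootstrap from the action actually taken next
        target = reward if done else reward + self.gamma * self.Q[next_obs, next_action]
        self.Q[obs, action] += self.lr * (target - self.Q[obs, action])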
Example #5
"""
  Implementation of the interaction between the Windy Gridworld environment
  and the SARSA agent using RL-Glue.
"""
from rl_glue import RLGlue
from env import WindygridEnvironment
from agent import SarsaAgent
import numpy as np
import matplotlib.pyplot as plt

if __name__ == "__main__":
    max_steps = 8000
    num_runs = 1

    # Create and pass agent and environment objects to RLGlue
    environment = WindygridEnvironment()
    agent = SarsaAgent()
    rlglue = RLGlue(environment, agent)
    del agent, environment  # don't use these anymore
    for run in range(num_runs):
        episode = []
        time_step = []
        rlglue.rl_init()
        while True:
            # Run one episode and record cumulative steps vs. completed episodes
            rlglue.rl_episode()
            time_step.append(rlglue.num_steps())
            episode.append(rlglue.num_episodes())
            if rlglue.num_steps() > max_steps:
                break

    plt.plot(time_step, episode, label="8 actions")
    plt.xticks([0, 1000, 2000, 3000, 4000, 5000, 6000, 7000, 8000])
Example #6
def feedback(self, state, action, reward, state2, action2=None, env=None):
    # Credit the reward to the index of the currently running option and
    # delegate to the plain SARSA update over option indices.
    assert self.option is not None
    opti = np.array([self.options.index(self.option)])
    SarsaAgent.feedback(self, state, opti, reward, state2, opti, env)
Example #7
def __init__(self, state_desc, action_desc):
    self.options = [DoorOption(), KeyOption(), LockOption()]
    SarsaAgent.__init__(self, state_desc, (len(self.options),))
    self.stateDim = len(state_desc)  # dims in state vector that we care about
    self.option = None
Example #8
import gym
from agent import SarsaAgent

EPISODES = 1000000
EPOCH = 10000

if __name__ == '__main__':
	env = gym.make('Blackjack-v0')
	agent = SarsaAgent(env.observation_space, env.action_space)
	
	win = 0
	loss = 0

	for episode in range(1, EPISODES + 1):
		done = False
		current_state = env.reset()
		current_action = agent.choose_action(current_state)
		reward = 0

		while not done:
			next_state, reward, done, _ = env.step(current_action)
			next_action = agent.choose_action(next_state)

			agent.update(current_state, current_action, reward, next_state, next_action)

			current_state = next_state
			current_action = next_action

		# Stats computation
		if reward > 0:
			win += 1
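
The Blackjack snippet is cut off after the win counter. Purely as an illustration of how the bookkeeping might continue (this completion is assumed, not taken from the original source), the loop could end with:

		elif reward < 0:
			loss += 1

		# Print a running summary every EPOCH episodes
		if episode % EPOCH == 0:
			print('Episode %d: wins = %d, losses = %d' % (episode, win, loss))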
Example #9
def test_sarsa():
    env = Maze()
    agent = SarsaAgent(act_n=4)
    rs = sarsa_demo(env, agent, 2000)
    plt.plot(range(2000), rs)
    plt.grid()
    plt.show()
Example #10
def test_gym():
    env = gym.make('MountainCar-v0')
    linear_func = LinearModel(feat_n, get_feature)
    agent = SarsaAgent(act_n=3, linear_func=linear_func)
    rs = gym_demo(env, agent, 500)
    plt.plot(range(500), rs)
    plt.grid()
    plt.show()
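
LinearModel, feat_n, and get_feature in Example #10 come from the surrounding module and are not shown. Below is a purely illustrative sketch of a linear action-value model for semi-gradient SARSA on MountainCar; the class name matches the call above, but its internals and method names here are assumptions, not the original code.

import numpy as np

class LinearModel:
    """Linear Q(s, a) = w[a] . phi(s), with one weight vector per action."""

    def __init__(self, feat_n, get_feature, act_n=3, lr=0.01):
        self.w = np.zeros((act_n, feat_n))
        self.get_feature = get_feature  # maps an observation to a feature vector
        self.lr = lr

    def value(self, obs, act):
        return float(self.w[act] @ self.get_feature(obs))

    def update(self, obs, act, target):
        # Semi-gradient step toward the SARSA target r + gamma * Q(s', a')
        phi = self.get_feature(obs)
        td_error = target - self.w[act] @ phi
        self.w[act] += self.lr * td_error * phi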
Example #11
		help='Number of seeds to average the learnt values')
	args = parser.parse_args()
	return args

if __name__ == '__main__':
	os.makedirs('plots', exist_ok=True)

	args = get_args()

	if not args.baseline and not args.kingmoves and not args.stochastic:
		print('Please pass at least one of the baseline, kingmoves or stochastic flags')

	if args.baseline:
		print('----------------- BaseLine -----------------')
		env = WindyGridWorld()
		agent = SarsaAgent(env, alpha=0.5, epsilon=0.1, save_plot_path='plots/baseline.png')
		agent.learn(num_seed_runs=args.num_seed_runs)

	if args.kingmoves:
		print('----------------- KingMoves -----------------')
		env = WindyGridWorldwithKingMoves()
		agent = SarsaAgent(env, alpha=0.5, epsilon=0.1, save_plot_path='plots/kingmoves.png')
		agent.learn(num_seed_runs=args.num_seed_runs)

	if args.stochastic:
		print('----------------- Stochastic with KingMoves -----------------')
		env = StochasticWindyGridWorld()
		agent = SarsaAgent(env, alpha=0.5, epsilon=0.1, save_plot_path='plots/stochastic.png')
		agent.learn(num_seed_runs=args.num_seed_runs)