def __init__(self, env_name, env, num_episodes, n_step=1, discount_factor=0.95,
             learning_rate=0.01, start_learning_rate=0.1, start_epsilon=1.0,
             decay_rate=0.001, action_space_n=None, render_env=False,
             make_checkpoint=False, is_state_box=False, batch_size=25,
             memory_capacity=1000, max_steps=200, dir_location=None):
    """Store the environment, hyper-parameters and per-episode statistics.

    Args:
        env_name: Name of the environment, stored as ``self.env_name``.
        env: Environment instance; ``env.action_space`` and
            ``env.observation_space`` are read here.
        num_episodes: Number of episodes; also sizes the zero-initialised
            ``episode_lengths`` / ``episode_rewards`` arrays in ``self.stats``.
        n_step: Stored as ``self.n_step``.
        discount_factor: Stored as ``self.discount_factor``.
        learning_rate: Stored as ``self.learning_rate``.
        start_learning_rate: Stored as ``self.start_learning_rate``.
        start_epsilon: Stored as ``self.start_epsilon``. Note that
            ``self.epsilon`` itself is initialised to 0 here (presumably
            reset from ``start_epsilon`` elsewhere -- TODO confirm).
        decay_rate: Stored as ``self.decay_rate``.
        action_space_n: Number of actions. When ``None``, it is read from
            ``env.action_space.n``.
        render_env: Stored as ``self.render_env``.
        make_checkpoint: Stored as ``self.make_checkpoint``.
        is_state_box: When true, ``self.nS`` is set to
            ``env.observation_space.shape[0]``; otherwise ``self.nS`` is 1.
        batch_size: Stored as ``self.batch_size``.
        memory_capacity: Stored as ``self.memory_capacity``.
        max_steps: Stored as ``self.MAX_STEPS`` (default 200, the value that
            used to be hard-coded). Presumably a per-episode step cap --
            confirm against the training loop.
        dir_location: Directory stored as ``self.dir_location``. When
            ``None``, the original hard-coded path is used, so existing
            callers are unaffected; pass a path to run on another machine.
    """
    # Bookkeeping; presumably overwritten when training starts -- TODO confirm.
    self.start_time = 0

    # Environment handles.
    self.env_name = env_name
    self.env = env
    self.render_env = render_env

    # Episode / step budget.
    self.MAX_STEPS = max_steps
    self.num_episodes = num_episodes

    # Learning hyper-parameters.
    self.start_learning_rate = start_learning_rate
    self.learning_rate = learning_rate
    self.discount_factor = discount_factor
    self.start_epsilon = start_epsilon
    self.epsilon = 0
    self.decay_rate = decay_rate
    self.n_step = n_step

    # Persistence: keep the historical default path for backward
    # compatibility, but allow callers to override the user-specific location.
    self.make_checkpoint = make_checkpoint
    if dir_location is None:
        dir_location = "/home/dsalwala/NUIG/Thesis/rl-algos/data"
    self.dir_location = dir_location

    # Space / replay configuration.
    self.is_state_box = is_state_box
    self.action_space_n = action_space_n
    self.batch_size = batch_size
    self.memory_capacity = memory_capacity

    # One slot per episode for lengths and rewards.
    self.stats = plotting.EpisodeStats(
        episode_lengths=np.zeros(num_episodes),
        episode_rewards=np.zeros(num_episodes))

    # Number of actions: explicit override wins, otherwise ask the env.
    self.nA = env.action_space.n if action_space_n is None else action_space_n

    # State dimensionality: vector length for box observations, else a
    # single scalar state.
    if self.is_state_box:
        self.nS = self.env.observation_space.shape[0]
    else:
        self.nS = 1
# Select best action to perform in a current state action = np.argmax(Q[state]) # Perform an action an observe how environment acted in response next_state, reward, terminated, info = env.step(action) # Update current state state = next_state # Calculate number of wins over episodes if terminated and reward == 1.0: break # Load a Windy GridWorld environment environment = CliffWalkingEnv() agent = QLearningAgent("CliffWalking-v0", environment, 1000, start_learning_rate=0.1, start_epsilon=1.0, discount_factor=0.95, decay_rate=0.001, make_checkpoint=True) # agent.train() Q, rewards, episode_len = agent.load("/home/dsalwala/NUIG/Thesis/rl-algos/data/CliffWalking-v0_1000.npy") stats = plotting.EpisodeStats( episode_lengths=episode_len, episode_rewards=rewards) # Search for a Q values # Q, stats = agent.q_table, agent.stats play_episode(environment, Q) plotting.plot_episode_stats(stats)