예제 #1
0
from rl_glue import RLGlue
from windy_env import WindyEnvironment
from sarsa_agent import SarsaAgent
import numpy as np
import matplotlib.pyplot as plt

# Step budget for the experiment: new episodes keep being started until the
# glue reports at least this many steps. The same value is also handed to
# rl_episode() on every call.
max_steps = 8000
steps = episodes = 0

# Parallel histories; one (steps, episodes) sample is appended per episode.
ep_list, step_list = [], []

environment = WindyEnvironment()
agent = SarsaAgent()
rl = RLGlue(environment, agent)
rl.rl_init()

while steps < max_steps:
    rl.rl_episode(max_steps)
    # Read both counters back from the glue once the episode call returns.
    steps, episodes = rl.num_steps(), rl.num_episodes()
    step_list.append(steps)
    ep_list.append(episodes)

# Plot the recorded episode counter against the recorded step counter.
plt.plot(step_list, ep_list)
plt.xlabel('Time steps')
plt.ylabel('Episodes')
plt.show()
예제 #2
0
import numpy as np
import matplotlib.pyplot as plt

if __name__ == "__main__":
    # Run the Sarsa agent in the windy-grid environment through RLGlue and
    # plot the value of rlglue.num_episodes() against rlglue.num_steps(),
    # sampled once after every episode.
    max_steps = 8000  # step budget; also bounds the x-axis ticks below
    num_runs = 1

    # Create and pass agent and environment objects to RLGlue
    environment = WindygridEnvironment()
    agent = SarsaAgent()
    rlglue = RLGlue(environment, agent)
    del agent, environment  # don't use these anymore
    for run in range(num_runs):
        # The histories are reset at the start of each run, so the plot
        # below only ever shows the samples of the last run.
        episode = []
        time_step = []
        rlglue.rl_init()
        while True:
            rlglue.rl_episode()
            time_step.append(rlglue.num_steps())
            episode.append(rlglue.num_episodes())
            # Stop on the configured budget instead of a second literal 8000,
            # so changing max_steps actually changes the run length.
            if rlglue.num_steps() > max_steps:
                break

    plt.plot(time_step, episode, label="8 actions")
    # One tick every 1000 steps, from 0 up to the step budget.
    plt.xticks(list(range(0, max_steps + 1, 1000)))
    plt.xlabel('Time steps')
    plt.ylabel('Episode', rotation=90)
    plt.legend(loc=2)
    plt.show()
예제 #3
0
from rl_glue import RLGlue
from windy_env import WindyEnvironment
from n_step_sarsa_agent import SarsaAgent
import numpy as np
import time
import matplotlib.pyplot as plt

if __name__ == "__main__":
    # Drive the agent imported from n_step_sarsa_agent through RLGlue until
    # the glue reports at least max_steps steps, then plot the samples
    # collected around each episode.
    start_time = time.time()  # captured here; not read again in this snippet
    max_steps = 8000

    # Hand the environment and the agent over to RLGlue; from this point on
    # only the glue object is used.
    environment = WindyEnvironment()
    agent = SarsaAgent()
    rlglue = RLGlue(environment, agent)
    del agent, environment
    rlglue.rl_init()

    steps_before = []    # rlglue.num_steps() read just before each episode
    episodes_after = []  # rlglue.num_episodes() read just after each episode

    # Values the agent returns for the messages 'n' and 'a'; they are only
    # used to build the plot title.
    n = rlglue.rl_agent_message('n')
    a = rlglue.rl_agent_message('a')

    while True:
        if rlglue.num_steps() >= max_steps:
            break
        # NOTE(review): the step counter is sampled before the episode while
        # the episode counter is sampled after it -- confirm this pairing is
        # the intended one for the plot.
        steps_before.append(rlglue.num_steps())
        rlglue.rl_episode(10000)
        episodes_after.append(rlglue.num_episodes())

    plot_title = str(n) + '-step sarsa with ' + str(a) + " actions"
    plt.title(plot_title)
    plt.plot(steps_before, episodes_after)
    plt.show()