Example 1
import numpy as np
import torch
import matplotlib.pyplot as plt

# NECAgent is assumed to be importable from the project's agent module.

def run_evaluation(config, path, episodes=50):
    env = config["env"]
    agent = NECAgent(config)
    agent.nec_net.load_state_dict(torch.load(path))
    env.eval()    # the project's env wrapper and agent both expose train()/eval() switches
    agent.eval()

    rewards = []

    for ep in range(1, episodes + 1):
        obs, reward_sum = env.reset(), 0

        while True:
            env.render(mode='rgb_array')
            obs = torch.from_numpy(np.float32(obs))
            action = agent.step(obs)
            next_obs, reward, done, info = env.step(action)
            reward_sum += reward
            obs = next_obs

            if done:
                if config['env_name'].startswith('CartPole'):
                    # CartPole also pays +1 on the terminal step; drop it from the episode score
                    reward_sum -= reward

                rewards.append(reward_sum)
                break

    plt.plot(range(1, episodes + 1), rewards)
    plt.savefig(f"eval_{config['exp_name']}.png")
Example 2
import os

import numpy as np
import torch
from tqdm import tqdm

# NECAgent is assumed to be importable from the project's agent module.

def run_training(config, return_agent=False):
    env = config["env"]
    env.train()
    agent = NECAgent(config)

    done = True  # forces env.reset() on the first loop iteration
    epsilon = config["initial_epsilon"]  # start from the configured value so the linear anneal below is consistent
    for t in tqdm(range(1, config["max_steps"] + 1)):
        if done:
            obs, done = env.reset(), False
            agent.new_episode()

        if config["epsilon_anneal_start"] < t <= config["epsilon_anneal_end"]:
            epsilon -= (config["initial_epsilon"] - config["final_epsilon"]
                        ) / (config["epsilon_anneal_end"] -
                             config["epsilon_anneal_start"])
            agent.set_epsilon(epsilon)

        # env.render()
        if isinstance(obs, np.ndarray):
            obs = torch.from_numpy(np.float32(obs))
        action = agent.step(obs.to(config['device']))
        next_obs, reward, done, info = env.step(action)
        solved = agent.update((reward, done))

        if solved:
            return

        obs = next_obs

        if t >= config["start_learning_step"]:
            if t % config["replay_frequency"] == 0:
                agent.optimize()

            if t % config["eval_frequency"] == 0:
                # agent.eval()
                # # evaluate agent here #
                path = f'{os.getcwd()}/pong/trained_agents/nec_{agent.exp_name}_{t // config["eval_frequency"]}.pth'
                torch.save(agent.nec_net.state_dict(), path)
                # agent.train()

    if return_agent:
        return agent
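
For reference, a sketch of the configuration run_training expects. Every key below is read directly in the function body; the values are illustrative assumptions, and NECAgent will generally need further agent-specific keys (exp_name, for instance, is read via agent.exp_name when checkpoints are saved):

# Illustrative values only; tune per experiment.
config = {
    "env": env,                      # project env wrapper with train()/reset()/step()
    "device": "cpu",                 # observations are moved here before agent.step
    "max_steps": 100_000,
    "initial_epsilon": 1.0,
    "final_epsilon": 0.01,
    "epsilon_anneal_start": 5_000,
    "epsilon_anneal_end": 25_000,
    "start_learning_step": 1_000,    # steps of pure experience collection before learning
    "replay_frequency": 4,           # optimize the network every 4 env steps
    "eval_frequency": 10_000,        # checkpoint (and optionally evaluate) every 10k steps
}
agent = run_training(config, return_agent=True)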