Example No. 1

# Standard-library and third-party imports used by this excerpt; GridworldEnv,
# QNetwork, ReplayMemory, run_her_episodes, train, loop_environments and
# line_plot_var are assumed to come from the surrounding project.
import os
import pickle
import random

import numpy as np
import torch

def her_experiment():
    batch_size = 256
    discount_factor = 0.8
    learn_rate = 1e-3
    num_hidden = 128
    num_episodes = 2
    epochs = 200
    training_steps = 10
    memory_size = 100000
    seeds = [42, 30, 2, 19, 99]  # fixed set of seeds (not randomly chosen) for reproducibility
    shape = [30, 30]
    # goal candidates: the four corners of the flattened x-by-y grid
    targets = lambda x, y: [0, x * y - 1, x - 1, (y - 1) * x]
    env = GridworldEnv(shape=shape, targets=targets(*shape))

    # functions for grid world
    def sample_goal():
        return np.random.choice(env.targets, 1)

    extract_goal = lambda state: np.reshape(np.array(np.argmax(state)), -1)

    def calc_reward(state, action, goal):
        if state == goal:
            return 0.0
        else:
            return -1.0

    # maze variant (alternative to the grid world callbacks above):
    # def sample_goal():
    #     return env.maze.end_pos
    # extract_goal = lambda state: np.reshape(np.array(np.argmax(state)), -1)
    # def calc_reward(state, action, goal):
    #     if state == goal:
    #         return 0.0
    #     else:
    #         return -1.0

    means = []
    x_epochs = []
    l_stds = []
    h_stds = []
    for her in [True, False]:
        episode_durations_all = []
        for seed in seeds:
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)
            env.seed(seed)
            print(env.reset())
            memory = ReplayMemory(memory_size)
            if her:
                # input is twice the state size (presumably state and goal representations concatenated)
                model = QNetwork(2 * env.observation_space.n, num_hidden,
                                 env.action_space.n)
                episode_durations, episode_rewards = run_her_episodes(
                    train,
                    model,
                    memory,
                    env,
                    num_episodes,
                    training_steps,
                    epochs,
                    batch_size,
                    discount_factor,
                    learn_rate,
                    sample_goal,
                    extract_goal,
                    calc_reward,
                    use_her=True)
            else:
                model = QNetwork(env.observation_space.n, num_hidden,
                                 env.action_space.n)
                episode_durations, episode_rewards = run_her_episodes(
                    train,
                    model,
                    memory,
                    env,
                    num_episodes,
                    training_steps,
                    epochs,
                    batch_size,
                    discount_factor,
                    learn_rate,
                    sample_goal,
                    extract_goal,
                    calc_reward,
                    use_her=False)

            episode_durations_all.append(
                loop_environments.smooth(episode_durations, 10))
        mean = np.mean(episode_durations_all, axis=0)
        means.append(mean)
        std = np.std(episode_durations_all, ddof=1, axis=0)
        l_stds.append(mean - std)
        h_stds.append(mean + std)
        x_epochs.append(list(range(len(mean))))
    line_plot_var(x_epochs, means, l_stds, h_stds, "Epoch", "Duration",
                  ["HindsightReplay", "RandomReplay"],
                  "Episode duration per epoch", ["orange", "blue"])
    name = "her_" + str(shape)
    file_name = os.path.join("./results", name)

    with open(file_name + ".pkl", "wb") as f:
        pickle.dump((x_epochs, means, l_stds, h_stds), f)
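
run_her_episodes is project-specific and not shown in this excerpt. As a rough illustration of what the sample_goal / extract_goal / calc_reward callbacks are for, here is a minimal sketch of hindsight relabeling, under the assumption that the "future" strategy from the HER paper is used; relabel_episode and its (state, action, next_state, goal) transition layout are hypothetical, not the project's actual API.

import numpy as np

def relabel_episode(episode, extract_goal, calc_reward):
    """Return original plus hindsight-relabeled transitions for one episode.

    `episode` is assumed to be a list of (state, action, next_state, goal) tuples.
    """
    transitions = []
    for t, (state, action, next_state, goal) in enumerate(episode):
        achieved = extract_goal(next_state)
        # transition with the goal that was originally pursued
        transitions.append(
            (state, action, calc_reward(achieved, action, goal), next_state, goal))
        # hindsight transition: relabel with a goal that was actually reached later on
        future_next = episode[np.random.randint(t, len(episode))][2]
        new_goal = extract_goal(future_next)
        transitions.append(
            (state, action, calc_reward(achieved, action, new_goal), next_state, new_goal))
    return transitions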
Example No. 2

# Standard-library and third-party imports used by this excerpt; config,
# GridworldEnv, QNetwork, ReplayMemory, EpsilonGreedyPolicy, run_episodes,
# train and plot_smooth are assumed to come from the surrounding project.
import random

import gym
import matplotlib.pyplot as plt
import torch

def main():
    print("Running DQN")

    if config.env == "GridWorldEnv":
        print("Playing: ", config.env)
        env = GridworldEnv()
    else:
        env_name = config.env
        print("Playing:", env_name)
        env = gym.make(env_name)

    # not 100 % sure this will work for all envs
    obs_shape = env.observation_space.shape
    num_actions = env.action_space.n
    assert len(obs_shape) <= 1, \
        "Not yet compatible with multi-dim observation space"
    if len(obs_shape) > 0:
        obs_size = obs_shape[0]
    else:
        obs_size = 1

    num_episodes = config.n_episodes
    batch_size = config.batch_size
    discount_factor = config.discount_factor
    learn_rate = config.learn_rate
    seed = config.seed
    num_hidden = config.num_hidden
    min_eps = config.min_eps
    max_eps = config.max_eps
    anneal_time = config.anneal_time
    clone_interval = config.clone_interval
    replay = not config.replay_off
    clipping = not config.clipping_off

    if config.memory_size is None:
        memory_size = 10 * batch_size
    else:
        memory_size = config.memory_size

    if not replay and (batch_size != 1 or memory_size != 1):
        print("Replay is turned off: adjusting memory and batch size to 1")
        batch_size = 1
        memory_size = 1

    memory = ReplayMemory(memory_size)

    # We will seed the algorithm (before initializing QNetwork!) for reproducibility
    random.seed(seed)
    torch.manual_seed(seed)
    env.seed(seed)

    Q_net = QNetwork(obs_size, num_actions, num_hidden=num_hidden)
    policy = EpsilonGreedyPolicy(Q_net, num_actions)
    episode_durations, losses, max_qs = run_episodes(
        train, Q_net, policy, memory, env, num_episodes, batch_size,
        discount_factor, learn_rate, clone_interval, min_eps, max_eps,
        anneal_time, clipping)

    plot_smooth(episode_durations, 10, show=True)

    # Quick-look plots for now. TODO: make a nicer plot function to test/compare multiple settings
    plt.plot(losses)
    plt.title(
        f"{config.env}, lr={learn_rate}, replay={replay}, clone_interval={clone_interval}"
    )
    plt.ylabel("Loss")
    plt.xlabel("Episode")
    plt.show()

    plt.plot(max_qs)
    if clipping:
        # with rewards clipped to [-1, 1], |Q| is bounded by sum_k gamma^k = 1 / (1 - gamma)
        plt.axhline(y=1. / (1 - discount_factor), color='r', linestyle='-')
    plt.title(
        f"{config.env}, lr={learn_rate}, replay={replay}, clone_interval={clone_interval}"
    )
    plt.ylabel("max |Q|")
    plt.xlabel("Episode")
    plt.show()
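
EpsilonGreedyPolicy and run_episodes are likewise project-specific. The min_eps, max_eps and anneal_time parameters suggest a linearly annealed epsilon-greedy exploration schedule; below is a minimal sketch of that idea, where get_epsilon and select_action are hypothetical helpers rather than the project's actual API.

import random

import torch

def get_epsilon(step, max_eps, min_eps, anneal_time):
    # decay epsilon linearly from max_eps to min_eps over anneal_time steps,
    # then keep it at min_eps
    frac = min(step / anneal_time, 1.0)
    return max_eps + frac * (min_eps - max_eps)

def select_action(q_net, obs, epsilon, num_actions):
    if random.random() < epsilon:
        return random.randrange(num_actions)  # explore
    with torch.no_grad():
        q_values = q_net(torch.as_tensor(obs, dtype=torch.float32))
    return int(torch.argmax(q_values).item())  # exploit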