Example #1
    # Excerpt from a unittest.TestCase; assumes `from gym import make` and that
    # MultiThreadEnv (from tf2rl) is imported in the surrounding test module.
    @classmethod
    def setUpClass(cls):
        cls.batch_size = 64
        cls.thread_pool = 4
        cls.max_episode_steps = 1000

        def env_fn():
            return make("Pendulum-v0")

        cls.continuous_sample_env = env_fn()
        cls.continuous_envs = MultiThreadEnv(
            env_fn=env_fn,
            batch_size=cls.batch_size,
            max_episode_steps=cls.max_episode_steps)

        def env_fn():
            return make("CartPole-v0")

        cls.discrete_sample_env = env_fn()
        cls.discrete_envs = MultiThreadEnv(
            env_fn=env_fn,
            batch_size=cls.batch_size,
            max_episode_steps=cls.max_episode_steps)
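
The snippet above only builds the batched environments; for reference, here is a minimal sketch (not part of the original test) of stepping them once. `py_observation()` and `step()` are used exactly as in Example #2 below; the helper's name, the random-action batching, and the assumption that `step()` accepts a batched NumPy action array are illustrative. Inside a test method one could call, e.g., `random_step(cls.continuous_envs, cls.continuous_sample_env, cls.batch_size)`.

import numpy as np

def random_step(envs, sample_env, batch_size):
    # Batched observations for all environments, shape (batch_size, obs_dim)
    obses = envs.py_observation()
    # One random action per environment, drawn from the single sample env's action space
    actions = np.stack([sample_env.action_space.sample() for _ in range(batch_size)])
    # Step every environment in parallel; same call pattern as in Example #2
    next_obses, rewards, dones, _ = envs.step(actions)
    return obses, actions, next_obses, rewards, dones
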
Example #2
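# Excerpt of a tf2rl explorer worker. It assumes `import logging`, `import time`,
# `import numpy as np`, and the tf2rl helpers used below (import_tf, MultiThreadEnv,
# ReplayBuffer, get_default_rb_dict) are available in the surrounding module.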
def explorer(global_rb,
             queue,
             trained_steps,
             is_training_done,
             lock,
             env_fn,
             policy_fn,
             set_weights_fn,
             noise_level,
             n_env=64,
             n_thread=4,
             buffer_size=1024,
             episode_max_steps=1000,
             gpu=0):
    """
    Collect transitions and store them in the prioritized replay buffer.

    :param global_rb (multiprocessing.managers.AutoProxy[PrioritizedReplayBuffer]):
        Prioritized replay buffer shared among multiple explorers and a single learner.
        Because this object is shared across processes, operations on it must be
        guarded with the `lock` object.
    :param queue (multiprocessing.Queue):
        A FIFO queue shared with the `learner` and `evaluator`, used to receive the
        latest network weights. It is process safe, so no explicit locking is needed.
    :param trained_steps (multiprocessing.Value):
        Number of gradient steps applied by the learner so far.
    :param is_training_done (multiprocessing.Event):
        Event used to signal whether training has finished.
    :param lock (multiprocessing.Lock):
        Lock guarding access to the shared replay buffer across processes.
    :param env_fn (function):
        Function that returns a new environment instance.
    :param policy_fn (function):
        Function that returns an explorer policy.
    :param set_weights_fn (function):
        Function that sets the policy's network weights received from `queue`.
    :param noise_level (float):
        Noise level for exploration. For epsilon-greedy policies (DQN variants) this is
        epsilon; for DDPG variants it is the variance of the Gaussian exploration noise.
    :param n_env (int):
        Number of environments to run in parallel. If greater than 1, `MultiThreadEnv`
        is used.
    :param n_thread (int):
        Number of threads used by `MultiThreadEnv`.
    :param buffer_size (int):
        Size of the local buffer. Once it is full, its transitions are added to `global_rb`.
    :param episode_max_steps (int):
        Maximum number of steps per episode.
    :param gpu (int):
        GPU id. If set to -1, the process runs on CPU only.
    """
    import_tf()
    logger = logging.getLogger("tf2rl")

    if n_env > 1:
        envs = MultiThreadEnv(env_fn=env_fn,
                              batch_size=n_env,
                              thread_pool=n_thread,
                              max_episode_steps=episode_max_steps)
        env = envs._sample_env
    else:
        env = env_fn()

    policy = policy_fn(env=env,
                       name="Explorer",
                       memory_capacity=global_rb.get_buffer_size(),
                       noise_level=noise_level,
                       gpu=gpu)

    # Local buffer that accumulates up to `buffer_size` transitions before they
    # are flushed into the shared global replay buffer.
    kwargs = get_default_rb_dict(buffer_size, env)
    if n_env > 1:
        kwargs["env_dict"]["priorities"] = {}
    local_rb = ReplayBuffer(**kwargs)
    local_idx = np.arange(buffer_size).astype(int)

    if n_env == 1:
        s = env.reset()
        episode_steps = 0
        total_reward = 0.
        total_rewards = []

    start = time.time()
    n_sample, n_sample_old = 0, 0

    # Keep collecting transitions until the learner signals that training is finished
    while not is_training_done.is_set():
        if n_env == 1:
            n_sample += 1
            episode_steps += 1
            a = policy.get_action(s)
            s_, r, done, _ = env.step(a)
            done_flag = done
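            # Treat hitting the step limit as truncation, not a true terminal state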
            if episode_steps == env._max_episode_steps:
                done_flag = False
            total_reward += r
            local_rb.add(obs=s, act=a, rew=r, next_obs=s_, done=done_flag)

            s = s_
            if done or episode_steps == episode_max_steps:
                s = env.reset()
                total_rewards.append(total_reward)
                total_reward = 0
                episode_steps = 0
        else:
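            # Batched path: step all n_env environments at once and use |TD error| as priorities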
            n_sample += n_env
            obses = envs.py_observation()
            actions = policy.get_action(obses, tensor=True)
            next_obses, rewards, dones, _ = envs.step(actions)
            td_errors = policy.compute_td_error(states=obses,
                                                actions=actions,
                                                next_states=next_obses,
                                                rewards=rewards,
                                                dones=dones)
            local_rb.add(obs=obses,
                         act=actions,
                         next_obs=next_obses,
                         rew=rewards,
                         done=dones,
                         priorities=np.abs(td_errors) + 1e-6)

        # Periodically refresh the explorer's network weights from the queue
        if not queue.empty():
            set_weights_fn(policy, queue.get())

        # Add collected experiences to global replay buffer
        if local_rb.get_stored_size() == buffer_size:
            samples = local_rb._encode_sample(local_idx)
            if n_env > 1:
                priorities = np.squeeze(samples["priorities"])
            else:
                td_errors = policy.compute_td_error(
                    states=samples["obs"],
                    actions=samples["act"],
                    next_states=samples["next_obs"],
                    rewards=samples["rew"],
                    dones=samples["done"])
                priorities = np.abs(np.squeeze(td_errors)) + 1e-6
            lock.acquire()
            global_rb.add(obs=samples["obs"],
                          act=samples["act"],
                          rew=samples["rew"],
                          next_obs=samples["next_obs"],
                          done=samples["done"],
                          priorities=priorities)
            lock.release()
            local_rb.clear()

            msg = "Grad: {0: 6d}\t".format(trained_steps.value)
            msg += "Samples: {0: 7d}\t".format(n_sample)
            msg += "TDErr: {0:.5f}\t".format(np.average(priorities))
            if n_env == 1:
                ave_rew = (0 if len(total_rewards) == 0 else
                           sum(total_rewards) / len(total_rewards))
                msg += "AveEpiRew: {0:.3f}\t".format(ave_rew)
                total_rewards = []
            msg += "FPS: {0:.2f}".format(
                (n_sample - n_sample_old) / (time.time() - start))
            logger.info(msg)

            start = time.time()
            n_sample_old = n_sample
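
The docstring above describes a set of multiprocessing primitives (a weight queue, a training-done event, a lock around the shared buffer). The sketch below strips that coordination pattern down to plain multiprocessing so it can run on its own: a learner process publishes fake "weights" on a Queue, worker processes poll it non-blockingly (mirroring the `if not queue.empty(): queue.get()` check in `explorer`), and an Event ends the run. All names and payloads here are stand-ins for illustration, not tf2rl code.

import multiprocessing as mp
import time


def worker(queue, is_training_done):
    # Stand-in for the explorer loop: poll for fresh weights until training finishes.
    weights = None
    while not is_training_done.is_set():
        if not queue.empty():
            weights = queue.get()   # non-blocking weight refresh, as in explorer()
        time.sleep(0.01)            # placeholder for environment interaction


def learner(queue, is_training_done, n_updates=5):
    # Stand-in for the learner: periodically publish weights, then signal completion.
    for step in range(n_updates):
        queue.put({"step": step})   # fake weights payload
        time.sleep(0.05)
    is_training_done.set()


if __name__ == "__main__":
    queue = mp.Queue()
    is_training_done = mp.Event()
    procs = [mp.Process(target=worker, args=(queue, is_training_done)) for _ in range(2)]
    procs.append(mp.Process(target=learner, args=(queue, is_training_done)))
    for p in procs:
        p.start()
    for p in procs:
        p.join()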