def setUpClass(cls):
    cls.batch_size = 64
    cls.thread_pool = 4
    cls.max_episode_steps = 1000

    def env_fn():
        return make("Pendulum-v0")

    cls.continuous_sample_env = env_fn()
    cls.continuous_envs = MultiThreadEnv(
        env_fn=env_fn,
        batch_size=cls.batch_size,
        max_episode_steps=cls.max_episode_steps)

    def env_fn():
        return make("CartPole-v0")

    cls.discrete_sample_env = env_fn()
    cls.discrete_envs = MultiThreadEnv(
        env_fn=env_fn,
        batch_size=cls.batch_size,
        max_episode_steps=cls.max_episode_steps)
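# A minimal sketch of how the fixture above could be exercised (not part of the
# original test suite): it steps the batched continuous environment once and
# checks the batch dimension. It assumes `MultiThreadEnv.py_observation()`
# returns a batched observation array and that `step()` accepts a batched
# action tensor, as in the explorer loop below; `test_batched_step` is a
# hypothetical name.
def test_batched_step(self):
    import numpy as np
    import tensorflow as tf

    # One random action per environment in the batch.
    actions = np.stack([
        self.continuous_sample_env.action_space.sample()
        for _ in range(self.batch_size)])
    obses = self.continuous_envs.py_observation()
    next_obses, rewards, dones, _ = self.continuous_envs.step(
        tf.constant(actions, dtype=tf.float32))
    self.assertEqual(np.asarray(obses).shape[0], self.batch_size)
    self.assertEqual(np.asarray(next_obses).shape[0], self.batch_size)
    self.assertEqual(np.asarray(rewards).shape[0], self.batch_size)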
def explorer(global_rb, queue, trained_steps, is_training_done,
             lock, env_fn, policy_fn, set_weights_fn, noise_level,
             n_env=64, n_thread=4, buffer_size=1024,
             episode_max_steps=1000, gpu=0):
    """
    Collect transitions and store them in the prioritized replay buffer.

    :param global_rb (multiprocessing.managers.AutoProxy[PrioritizedReplayBuffer]):
        Prioritized replay buffer shared between multiple explorers and a single
        learner. Because this object is shared across processes, any operation on
        it must be guarded with the `lock` object.
    :param queue (multiprocessing.Queue):
        A FIFO shared with the `learner` and `evaluator` to receive the latest
        network weights. It is process safe, so no locking is needed when using it.
    :param trained_steps (multiprocessing.Value):
        Number of gradient-update steps applied so far.
    :param is_training_done (multiprocessing.Event):
        multiprocessing.Event object used to share the training status.
    :param lock (multiprocessing.Lock):
        multiprocessing.Lock used to serialize access to shared objects.
    :param env_fn (function):
        Method object that generates an environment.
    :param policy_fn (function):
        Method object that generates an explorer policy.
    :param set_weights_fn (function):
        Method object that sets the network weights received from the queue.
    :param noise_level (float):
        Noise level for exploration. For an epsilon-greedy policy such as DQN
        variants this is epsilon; for DDPG variants it is the variance of the
        Normal distribution.
    :param n_env (int):
        Number of environments to distribute. If this is greater than 1,
        `MultiThreadEnv` is used.
    :param n_thread (int):
        Number of threads used in `MultiThreadEnv`.
    :param buffer_size (int):
        Size of the local buffer. Once it is filled with transitions, they are
        added to `global_rb`.
    :param episode_max_steps (int):
        Maximum number of steps in an episode.
    :param gpu (int):
        GPU id. If set to -1, this process uses only the CPU.
    """
    import_tf()
    logger = logging.getLogger("tf2rl")

    if n_env > 1:
        envs = MultiThreadEnv(
            env_fn=env_fn, batch_size=n_env, thread_pool=n_thread,
            max_episode_steps=episode_max_steps)
        env = envs._sample_env
    else:
        env = env_fn()

    policy = policy_fn(
        env=env, name="Explorer",
        memory_capacity=global_rb.get_buffer_size(),
        noise_level=noise_level, gpu=gpu)

    kwargs = get_default_rb_dict(buffer_size, env)
    if n_env > 1:
        kwargs["env_dict"]["priorities"] = {}
    local_rb = ReplayBuffer(**kwargs)
    local_idx = np.arange(buffer_size).astype(np.int64)

    if n_env == 1:
        s = env.reset()
        episode_steps = 0
        total_reward = 0.
    total_rewards = []
    start = time.time()
    n_sample, n_sample_old = 0, 0

    while not is_training_done.is_set():
        if n_env == 1:
            n_sample += 1
            episode_steps += 1
            a = policy.get_action(s)
            s_, r, done, _ = env.step(a)
            # Do not treat a time-limit termination as a true terminal state
            done_flag = done
            if episode_steps == env._max_episode_steps:
                done_flag = False
            total_reward += r
            local_rb.add(obs=s, act=a, rew=r, next_obs=s_, done=done_flag)

            s = s_
            if done or episode_steps == episode_max_steps:
                s = env.reset()
                total_rewards.append(total_reward)
                total_reward = 0
                episode_steps = 0
        else:
            n_sample += n_env
            obses = envs.py_observation()
            actions = policy.get_action(obses, tensor=True)
            next_obses, rewards, dones, _ = envs.step(actions)
            td_errors = policy.compute_td_error(
                states=obses, actions=actions, next_states=next_obses,
                rewards=rewards, dones=dones)
            # Small constant keeps priorities strictly positive
            local_rb.add(
                obs=obses, act=actions, next_obs=next_obses, rew=rewards,
                done=dones, priorities=np.abs(td_errors) + 1e-6)

        # Periodically copy the latest weights into the explorer policy
        if not queue.empty():
            set_weights_fn(policy, queue.get())

        # Add collected experiences to the global replay buffer
        if local_rb.get_stored_size() == buffer_size:
            samples = local_rb._encode_sample(local_idx)
            if n_env > 1:
                priorities = np.squeeze(samples["priorities"])
            else:
                td_errors = policy.compute_td_error(
                    states=samples["obs"], actions=samples["act"],
                    next_states=samples["next_obs"],
                    rewards=samples["rew"], dones=samples["done"])
                priorities = np.abs(np.squeeze(td_errors)) + 1e-6

            lock.acquire()
            global_rb.add(
                obs=samples["obs"], act=samples["act"], rew=samples["rew"],
                next_obs=samples["next_obs"], done=samples["done"],
                priorities=priorities)
            lock.release()

            local_rb.clear()

            msg = "Grad: {0: 6d}\t".format(trained_steps.value)
            msg += "Samples: {0: 7d}\t".format(n_sample)
            msg += "TDErr: {0:.5f}\t".format(np.average(priorities))
            if n_env == 1:
                ave_rew = (0 if len(total_rewards) == 0
                           else sum(total_rewards) / len(total_rewards))
                msg += "AveEpiRew: {0:.3f}\t".format(ave_rew)
                total_rewards = []
            msg += "FPS: {0:.2f}".format(
                (n_sample - n_sample_old) / (time.time() - start))
            logger.info(msg)

            start = time.time()
            n_sample_old = n_sample
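# A minimal wiring sketch (not the library's own launcher) showing how the
# `explorer` above could be started in its own process. Only the explorer
# signature and the standard multiprocessing primitives come from the code
# above; `launch_explorer`, `make_policy`, and `set_weights` are hypothetical
# stand-ins for user-supplied `policy_fn`/`set_weights_fn` helpers, and
# `global_rb` is assumed to be a manager-proxied PrioritizedReplayBuffer
# created elsewhere.
import multiprocessing as mp

import gym


def _pendulum_env_fn():
    return gym.make("Pendulum-v0")


def launch_explorer(global_rb, make_policy, set_weights, noise_level=0.1):
    queue = mp.Queue()                  # learner -> explorer weight updates
    trained_steps = mp.Value("i", 0)    # gradient-step counter shared with the learner
    is_training_done = mp.Event()       # set by the learner when training is finished
    lock = mp.Lock()                    # serializes access to the shared replay buffer

    p = mp.Process(
        target=explorer,
        args=(global_rb, queue, trained_steps, is_training_done, lock,
              _pendulum_env_fn, make_policy, set_weights, noise_level),
        kwargs={"n_env": 1, "gpu": -1})  # single local env, CPU-only explorer
    p.start()
    return p, queue, is_training_done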