def test_update_envs_env_update():
    """Sampling with per-worker env updates yields varied tasks and rewards.

    Also verifies that supplying more env updates than workers raises
    ``ValueError``.
    """
    max_episode_length = 16
    env = PointEnv(max_episode_length=max_episode_length)
    scripted = [env.action_space.sample() for _ in range(max_episode_length)]
    policy = FixedPolicy(env.spec, scripted_actions=scripted)
    tasks = SetTaskSampler(PointEnv)
    n_workers = 8
    factory = WorkerFactory(seed=100,
                            max_episode_length=max_episode_length,
                            n_workers=n_workers)
    sampler = MultiprocessingSampler.from_worker_factory(factory, policy, env)
    episodes = sampler.obtain_samples(0,
                                      161,
                                      np.asarray(policy.get_param_values()),
                                      env_update=tasks.sample(n_workers))
    split_eps = list(episodes.split())
    mean_rewards = [eps.rewards.mean() for eps in split_eps]
    goals = [eps.env_infos['task'][0]['goal'] for eps in split_eps]
    # Different tasks should produce different goals and reward profiles.
    assert np.var(mean_rewards) > 0
    assert np.var(goals) > 0
    # One env update per worker is required; n_workers + 1 is rejected.
    with pytest.raises(ValueError):
        sampler.obtain_samples(0,
                               10,
                               np.asarray(policy.get_param_values()),
                               env_update=tasks.sample(n_workers + 1))
    sampler.shutdown_worker()
    env.close()
def test_init_with_crashed_worker():
    """Sampling still completes when one worker's policy crashes on reset."""
    max_episode_length = 16
    env = PointEnv()
    scripted = [env.action_space.sample() for _ in range(max_episode_length)]
    policy = FixedPolicy(env.spec, scripted_actions=scripted)
    tasks = SetTaskSampler(PointEnv)
    n_workers = 2
    factory = WorkerFactory(seed=100,
                            max_episode_length=max_episode_length,
                            n_workers=n_workers)

    class CrashingPolicy:
        """Policy stub whose reset always fails, crashing its worker."""

        def reset(self, **kwargs):
            raise Exception('Intentional subprocess crash')

    bad_policy = CrashingPolicy()
    # This causes worker 2 to crash.
    sampler = MultiprocessingSampler.from_worker_factory(
        factory, [policy, bad_policy], envs=tasks.sample(n_workers))
    episodes = sampler.obtain_samples(0, 160, None)
    # The surviving worker alone should still gather the requested samples.
    assert sum(episodes.lengths) >= 160
    sampler.shutdown_worker()
    env.close()
def test_obtain_exact_episodes():
    """Each worker produces exactly n episodes with its own fixed action."""
    max_episode_length = 15
    n_workers = 8
    env = PointEnv()
    per_worker_actions = [env.action_space.sample() for _ in range(n_workers)]
    policies = [
        FixedPolicy(env.spec, [action] * max_episode_length)
        for action in per_worker_actions
    ]
    factory = WorkerFactory(seed=100,
                            max_episode_length=max_episode_length,
                            n_workers=n_workers)
    sampler = MultiprocessingSampler.from_worker_factory(factory,
                                                         policies,
                                                         envs=env)
    n_eps_per_worker = 3
    episodes = sampler.obtain_exact_episodes(n_eps_per_worker,
                                             agent_update=policies)
    # At least one action per episode.
    assert sum(episodes.lengths) >= n_workers * n_eps_per_worker
    # All of the episodes.
    assert len(episodes.lengths) == n_workers * n_eps_per_worker
    # Episodes arrive grouped by worker; map each back to its fixed action.
    for count, eps in enumerate(episodes.split()):
        worker = count // n_eps_per_worker
        assert (eps.actions == per_worker_actions[worker]).all()
    sampler.shutdown_worker()
    env.close()
def test_pickle():
    """A pickled-and-restored sampler keeps working after the original dies."""
    max_episode_length = 16
    env = PointEnv()
    scripted = [env.action_space.sample() for _ in range(max_episode_length)]
    policy = FixedPolicy(env.spec, scripted_actions=scripted)
    tasks = SetTaskSampler(PointEnv)
    n_workers = 4
    factory = WorkerFactory(seed=100,
                            max_episode_length=max_episode_length,
                            n_workers=n_workers)
    sampler = MultiprocessingSampler.from_worker_factory(factory, policy, env)
    sampler_pickled = pickle.dumps(sampler)
    # Shut down the original before restoring, to prove independence.
    sampler.shutdown_worker()
    sampler2 = pickle.loads(sampler_pickled)
    episodes = sampler2.obtain_samples(0,
                                       500,
                                       np.asarray(policy.get_param_values()),
                                       env_update=tasks.sample(n_workers))
    split_eps = list(episodes.split())
    mean_rewards = [eps.rewards.mean() for eps in split_eps]
    goals = [eps.env_infos['task'][0]['goal'] for eps in split_eps]
    # Task variety must survive the pickle round-trip.
    assert np.var(mean_rewards) > 0
    assert np.var(goals) > 0
    sampler2.shutdown_worker()
    env.close()
def test_init_with_env_updates():
    """Constructing from env updates (instead of envs) samples normally."""
    max_episode_length = 16
    env = PointEnv()
    scripted = [env.action_space.sample() for _ in range(max_episode_length)]
    policy = FixedPolicy(env.spec, scripted_actions=scripted)
    tasks = SetTaskSampler(PointEnv)
    n_workers = 8
    factory = WorkerFactory(seed=100,
                            max_episode_length=max_episode_length,
                            n_workers=n_workers)
    sampler = MultiprocessingSampler.from_worker_factory(
        factory, policy, envs=tasks.sample(n_workers))
    episodes = sampler.obtain_samples(0, 160, policy)
    assert sum(episodes.lengths) >= 160
    sampler.shutdown_worker()
    env.close()
def test_init_without_worker_factory():
    """Direct construction builds an equivalent factory; no args is an error."""
    max_episode_length = 16
    env = PointEnv()
    scripted = [env.action_space.sample() for _ in range(max_episode_length)]
    policy = FixedPolicy(env.spec, scripted_actions=scripted)
    sampler = MultiprocessingSampler(agents=policy,
                                     envs=env,
                                     seed=100,
                                     max_episode_length=max_episode_length)
    reference_factory = WorkerFactory(seed=100,
                                      max_episode_length=max_episode_length)
    # The internally-built factory should match one built by hand.
    assert sampler._factory._seed == reference_factory._seed
    assert (sampler._factory._max_episode_length ==
            reference_factory._max_episode_length)
    # Omitting both the factory and its construction args must fail.
    with pytest.raises(TypeError, match='Must construct a sampler from'):
        MultiprocessingSampler(agents=policy, envs=env)
    sampler.shutdown_worker()
    env.close()