These snippets come from garage's sampler test suite. A shared import preamble, assuming garage's module layout (exact paths vary between garage versions; TfEnv and GarageEnv exist only in older releases, and ray_local_session_fixture is a pytest fixture provided by garage's test helpers):

import pickle

import numpy as np
import pytest
import ray

from garage.envs import GarageEnv, PointEnv
from garage.experiment.task_sampler import SetTaskSampler
from garage.sampler import (FixedPolicy, FragmentWorker, LocalSampler,
                            MultiprocessingSampler, RaySampler, WorkerFactory)
from garage.tf.envs import TfEnv  # older garage API only

Example #1
def test_update_envs_env_update(ray_local_session_fixture):
    del ray_local_session_fixture
    assert ray.is_initialized()
    max_episode_length = 16
    env = PointEnv()
    policy = FixedPolicy(env.spec,
                         scripted_actions=[
                             env.action_space.sample()
                             for _ in range(max_episode_length)
                         ])
    tasks = SetTaskSampler(PointEnv)
    n_workers = 8
    workers = WorkerFactory(seed=100,
                            max_episode_length=max_episode_length,
                            n_workers=n_workers)
    sampler = RaySampler.from_worker_factory(workers, policy, env)
    episodes = sampler.obtain_samples(0,
                                      160,
                                      np.asarray(policy.get_param_values()),
                                      env_update=tasks.sample(n_workers))
    mean_rewards = []
    goals = []
    for eps in episodes.split():
        mean_rewards.append(eps.rewards.mean())
        goals.append(eps.env_infos['task'][0]['goal'])
    assert np.var(mean_rewards) > 0
    assert np.var(goals) > 0
    with pytest.raises(ValueError):
        sampler.obtain_samples(0,
                               10,
                               np.asarray(policy.get_param_values()),
                               env_update=tasks.sample(n_workers + 1))
Example #2
def test_update_envs_env_update():
    max_episode_length = 16
    env = PointEnv(max_episode_length=max_episode_length)
    policy = FixedPolicy(env.spec,
                         scripted_actions=[
                             env.action_space.sample()
                             for _ in range(max_episode_length)
                         ])
    tasks = SetTaskSampler(PointEnv)
    n_workers = 8
    workers = WorkerFactory(seed=100,
                            max_episode_length=max_episode_length,
                            n_workers=n_workers)
    sampler = MultiprocessingSampler.from_worker_factory(workers, policy, env)
    episodes = sampler.obtain_samples(0,
                                      500,
                                      np.asarray(policy.get_param_values()),
                                      env_update=tasks.sample(n_workers))
    mean_rewards = []
    goals = []
    for eps in episodes.split():
        mean_rewards.append(eps.rewards.mean())
        goals.append(eps.env_infos['task'][0]['goal'])
    assert np.var(mean_rewards) > 1e-3
    assert np.var(goals) > 1e-3
    with pytest.raises(ValueError):
        sampler.obtain_samples(0,
                               10,
                               np.asarray(policy.get_param_values()),
                               env_update=tasks.sample(n_workers + 1))
    sampler.shutdown_worker()
    env.close()
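Unlike the Ray-based examples, which rely on the ray_local_session_fixture for teardown, the MultiprocessingSampler examples shut their workers down explicitly with shutdown_worker() and close the environment afterwards.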
Example #3
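Note: this example and several later ones (those using the GarageEnv or TfEnv wrappers, a max_path_length argument, and "rollouts"/"trajectories" naming) target an older garage API; the newer snippets pass max_episode_length to PointEnv directly and speak of "episodes".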
def test_update_envs_env_update():
    max_path_length = 16
    env = GarageEnv(PointEnv())
    policy = FixedPolicy(env.spec,
                         scripted_actions=[
                             env.action_space.sample()
                             for _ in range(max_path_length)
                         ])
    tasks = SetTaskSampler(PointEnv)
    n_workers = 8
    workers = WorkerFactory(seed=100,
                            max_path_length=max_path_length,
                            n_workers=n_workers)
    sampler = LocalSampler.from_worker_factory(workers, policy, env)
    rollouts = sampler.obtain_samples(0,
                                      161,
                                      np.asarray(policy.get_param_values()),
                                      env_update=tasks.sample(n_workers))
    mean_rewards = []
    goals = []
    for rollout in rollouts.split():
        mean_rewards.append(rollout.rewards.mean())
        goals.append(rollout.env_infos['task'][0]['goal'])
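    # 161 requested steps with 16-step rollouts force an eleventh rollout:
    # ten complete rollouts cover only 160 steps, so one more is collected.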
    assert len(mean_rewards) == 11
    assert len(goals) == 11
    assert np.var(mean_rewards) > 1e-2
    assert np.var(goals) > 1e-2
    with pytest.raises(ValueError):
        sampler.obtain_samples(0,
                               10,
                               np.asarray(policy.get_param_values()),
                               env_update=tasks.sample(n_workers + 1))
Example #4
def test_pickle():
    max_path_length = 16
    env = TfEnv(PointEnv())
    policy = FixedPolicy(env.spec,
                         scripted_actions=[
                             env.action_space.sample()
                             for _ in range(max_path_length)
                         ])
    tasks = SetTaskSampler(PointEnv)
    n_workers = 8
    workers = WorkerFactory(seed=100,
                            max_path_length=max_path_length,
                            n_workers=n_workers)
    sampler = MultiprocessingSampler.from_worker_factory(workers, policy, env)
    sampler_pickled = pickle.dumps(sampler)
    sampler.shutdown_worker()
    sampler2 = pickle.loads(sampler_pickled)
    rollouts = sampler2.obtain_samples(0,
                                       161,
                                       np.asarray(policy.get_param_values()),
                                       env_update=tasks.sample(n_workers))
    mean_rewards = []
    goals = []
    for rollout in rollouts.split():
        mean_rewards.append(rollout.rewards.mean())
        goals.append(rollout.env_infos['task'][0]['goal'])
    assert np.var(mean_rewards) > 0
    assert np.var(goals) > 0
    sampler2.shutdown_worker()
    env.close()
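Note the ordering in test_pickle: the original sampler is shut down before the pickled copy is loaded, so two sets of worker processes are never alive at the same time.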
Example #5
def test_obtain_exact_trajectories(ray_local_session_fixture):
    del ray_local_session_fixture
    assert ray.is_initialized()
    max_path_length = 15
    n_workers = 8
    env = GarageEnv(PointEnv())
    per_worker_actions = [env.action_space.sample() for _ in range(n_workers)]
    policies = [
        FixedPolicy(env.spec, [action] * max_path_length)
        for action in per_worker_actions
    ]
    workers = WorkerFactory(seed=100,
                            max_path_length=max_path_length,
                            n_workers=n_workers)
    sampler = RaySampler.from_worker_factory(workers, policies, envs=env)
    n_traj_per_worker = 3
    rollouts = sampler.obtain_exact_trajectories(n_traj_per_worker, policies)
    # At least one action per trajectory.
    assert sum(rollouts.lengths) >= n_workers * n_traj_per_worker
    # All of the trajectories.
    assert len(rollouts.lengths) == n_workers * n_traj_per_worker
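    # rollouts.split() yields trajectories grouped by worker in order, so the
    # loop below advances to the next worker's scripted action every
    # n_traj_per_worker trajectories.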
    worker = -1
    for count, rollout in enumerate(rollouts.split()):
        if count % n_traj_per_worker == 0:
            worker += 1
        assert (rollout.actions == per_worker_actions[worker]).all()
Example #6
def test_obtain_exact_episodes():
    max_episode_length = 15
    n_workers = 8
    env = PointEnv()
    per_worker_actions = [env.action_space.sample() for _ in range(n_workers)]
    policies = [
        FixedPolicy(env.spec, [action] * max_episode_length)
        for action in per_worker_actions
    ]
    workers = WorkerFactory(seed=100,
                            max_episode_length=max_episode_length,
                            n_workers=n_workers)
    sampler = LocalSampler.from_worker_factory(workers, policies, envs=env)
    n_eps_per_worker = 3
    episodes = sampler.obtain_exact_episodes(n_eps_per_worker,
                                             agent_update=policies)
    # At least one action per episode.
    assert sum(episodes.lengths) >= n_workers * n_eps_per_worker
    # All of the episodes.
    assert len(episodes.lengths) == n_workers * n_eps_per_worker
    worker = -1
    for count, eps in enumerate(episodes.split()):
        if count % n_eps_per_worker == 0:
            worker += 1
        assert (eps.actions == per_worker_actions[worker]).all()
Example #7
def test_init_with_crashed_worker():
    max_episode_length = 16
    env = GarageEnv(PointEnv())
    policy = FixedPolicy(env.spec,
                         scripted_actions=[
                             env.action_space.sample()
                             for _ in range(max_episode_length)
                         ])
    tasks = SetTaskSampler(lambda: GarageEnv(PointEnv()))
    n_workers = 2
    workers = WorkerFactory(seed=100,
                            max_episode_length=max_episode_length,
                            n_workers=n_workers)

    class CrashingPolicy:
        def reset(self, **kwargs):
            raise Exception('Intentional subprocess crash')

    bad_policy = CrashingPolicy()

    # This causes worker 2 to crash.
    sampler = MultiprocessingSampler.from_worker_factory(
        workers, [policy, bad_policy], envs=tasks.sample(n_workers))
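    # Sampling still gathers a full batch below, implying the sampler
    # recovers from the subprocess that crashed on reset().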
    rollouts = sampler.obtain_samples(0, 160, None)
    assert sum(rollouts.lengths) >= 160
    sampler.shutdown_worker()
    env.close()
Example #8
def test_obtain_exact_trajectories():
    max_episode_length = 15
    n_workers = 8
    env = GarageEnv(PointEnv())
    per_worker_actions = [env.action_space.sample() for _ in range(n_workers)]
    policies = [
        FixedPolicy(env.spec, [action] * max_episode_length)
        for action in per_worker_actions
    ]
    workers = WorkerFactory(seed=100,
                            max_episode_length=max_episode_length,
                            n_workers=n_workers)
    sampler = MultiprocessingSampler.from_worker_factory(workers,
                                                         policies,
                                                         envs=env)
    n_traj_per_worker = 3
    rollouts = sampler.obtain_exact_trajectories(n_traj_per_worker,
                                                 agent_update=policies)
    # At least one action per trajectory.
    assert sum(rollouts.lengths) >= n_workers * n_traj_per_worker
    # All of the trajectories.
    assert len(rollouts.lengths) == n_workers * n_traj_per_worker
    worker = -1
    for count, rollout in enumerate(rollouts.split()):
        if count % n_traj_per_worker == 0:
            worker += 1
        assert (rollout.actions == per_worker_actions[worker]).all()
    sampler.shutdown_worker()
    env.close()
Example #9
def test_get_actions():
    policy = FixedPolicy(None, np.array([1, 2, 3]))
    assert policy.get_actions(np.array([0]).reshape(1, 1))[0] == 1
    assert policy.get_action(np.array([0]))[0] == 2
    assert policy.get_action(np.array([0]))[0] == 3
    with pytest.raises(IndexError):
        # All three scripted actions are consumed, so this call fails.
        policy.get_action(np.array([0]))
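For reference, a minimal sketch of the behaviour these assertions rely on. MinimalFixedPolicy is an illustrative stand-in, not garage's actual FixedPolicy: each get_actions call consumes the next scripted action, and any vectorized use (more than one environment) is rejected, which is also what Example #17 below asserts.

class MinimalFixedPolicy:
    """Illustrative stand-in: replays a fixed action script, one per call."""

    def __init__(self, env_spec, scripted_actions):
        self._scripted_actions = scripted_actions
        self._index = 0

    def reset(self, do_resets=None):
        # Restart the script; only a single (non-vectorized) env is allowed.
        if do_resets is not None and len(do_resets) != 1:
            raise ValueError('FixedPolicy does not support vectorization')
        self._index = 0

    def get_actions(self, observations):
        if len(observations) != 1:
            raise ValueError('FixedPolicy does not support vectorization')
        # Raises IndexError once the script is exhausted, as asserted above.
        action = self._scripted_actions[self._index]
        self._index += 1
        return np.array([action]), {}

    def get_action(self, observation):
        actions, info = self.get_actions(np.asarray([observation]))
        return actions[0], info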
Example #10
def test_no_seed():
    max_episode_length = 16
    env = PointEnv()
    policy = FixedPolicy(env.spec,
                         scripted_actions=[
                             env.action_space.sample()
                             for _ in range(max_episode_length)
                         ])
    n_workers = 8
    workers = WorkerFactory(seed=None,
                            max_episode_length=max_episode_length,
                            n_workers=n_workers)
    sampler = LocalSampler.from_worker_factory(workers, policy, env)
    episodes = sampler.obtain_samples(0, 160, policy)
    assert sum(episodes.lengths) >= 160
Example #11
def test_init_with_env_updates():
    max_episode_length = 16
    env = PointEnv()
    policy = FixedPolicy(env.spec,
                         scripted_actions=[
                             env.action_space.sample()
                             for _ in range(max_episode_length)
                         ])
    tasks = SetTaskSampler(PointEnv)
    n_workers = 8
    workers = WorkerFactory(seed=100,
                            max_episode_length=max_episode_length,
                            n_workers=n_workers)
    sampler = LocalSampler.from_worker_factory(workers,
                                               policy,
                                               envs=tasks.sample(n_workers))
    episodes = sampler.obtain_samples(0, 160, policy)
    assert sum(episodes.lengths) >= 160
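Examples #11, #12, #14, and #15 are the same initialization test written against LocalSampler, RaySampler, and MultiprocessingSampler; the from_worker_factory signature and env-update handling are the same across all three sampler classes.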
Example #12
def test_init_with_env_updates():
    max_path_length = 16
    env = TfEnv(PointEnv())
    policy = FixedPolicy(env.spec,
                         scripted_actions=[
                             env.action_space.sample()
                             for _ in range(max_path_length)
                         ])
    tasks = SetTaskSampler(lambda: TfEnv(PointEnv()))
    n_workers = 8
    workers = WorkerFactory(seed=100,
                            max_path_length=max_path_length,
                            n_workers=n_workers)
    sampler = RaySampler.from_worker_factory(workers,
                                             policy,
                                             envs=tasks.sample(n_workers))
    rollouts = sampler.obtain_samples(0, 160, policy)
    assert sum(rollouts.lengths) >= 160
Example #13
def test_init_without_worker_factory():
    max_episode_length = 16
    env = PointEnv()
    policy = FixedPolicy(env.spec,
                         scripted_actions=[
                             env.action_space.sample()
                             for _ in range(max_episode_length)
                         ])
    sampler = LocalSampler(agents=policy,
                           envs=env,
                           seed=100,
                           max_episode_length=max_episode_length)
    worker_factory = WorkerFactory(seed=100,
                                   max_episode_length=max_episode_length)
    assert sampler._factory._seed == worker_factory._seed
    assert (sampler._factory._max_episode_length ==
            worker_factory._max_episode_length)
    with pytest.raises(TypeError, match='Must construct a sampler from'):
        LocalSampler(agents=policy, envs=env)
Example #14
def test_init_with_env_updates():
    max_episode_length = 16
    env = GarageEnv(PointEnv())
    policy = FixedPolicy(env.spec,
                         scripted_actions=[
                             env.action_space.sample()
                             for _ in range(max_episode_length)
                         ])
    tasks = SetTaskSampler(lambda: GarageEnv(PointEnv()))
    n_workers = 8
    workers = WorkerFactory(seed=100,
                            max_episode_length=max_episode_length,
                            n_workers=n_workers)
    sampler = MultiprocessingSampler.from_worker_factory(
        workers, policy, envs=tasks.sample(n_workers))
    rollouts = sampler.obtain_samples(0, 160, policy)
    assert sum(rollouts.lengths) >= 160
    sampler.shutdown_worker()
    env.close()
Example #15
def test_init_with_env_updates(ray_local_session_fixture):
    del ray_local_session_fixture
    assert ray.is_initialized()
    max_episode_length = 16
    env = GarageEnv(PointEnv())
    policy = FixedPolicy(env.spec,
                         scripted_actions=[
                             env.action_space.sample()
                             for _ in range(max_episode_length)
                         ])
    tasks = SetTaskSampler(lambda: GarageEnv(PointEnv()))
    n_workers = 8
    workers = WorkerFactory(seed=100,
                            max_episode_length=max_episode_length,
                            n_workers=n_workers)
    sampler = RaySampler.from_worker_factory(workers,
                                             policy,
                                             envs=tasks.sample(n_workers))
    rollouts = sampler.obtain_samples(0, 160, policy)
    assert sum(rollouts.lengths) >= 160
Example #16
@pytest.mark.parametrize('timesteps_per_call', [1, 2])  # values assumed
def test_update_envs_env_update(timesteps_per_call):
    max_episode_length = 16
    env = PointEnv()
    n_workers = 8
    policies = [
        FixedPolicy(env.spec,
                    scripted_actions=[
                        env.action_space.sample()
                        for _ in range(max_episode_length)
                    ]) for _ in range(n_workers)
    ]
    tasks = SetTaskSampler(PointEnv)
    workers = WorkerFactory(seed=100,
                            max_episode_length=max_episode_length,
                            n_workers=n_workers,
                            worker_class=FragmentWorker,
                            worker_args=dict(
                                n_envs=1,
                                timesteps_per_call=timesteps_per_call))
    sampler = LocalSampler.from_worker_factory(workers, policies, env)
    episodes = sampler.obtain_samples(0,
                                      160,
                                      None,
                                      env_update=tasks.sample(n_workers))
    mean_rewards = []
    goals = []
    for eps in episodes.split():
        mean_rewards.append(eps.rewards.mean())
        goals.append(eps.env_infos['task'][0]['goal'])
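    # Each FragmentWorker call contributes timesteps_per_call steps from its
    # single env, so the 160 requested steps split into exactly
    # 160 / timesteps_per_call fragments.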
    assert len(mean_rewards) == int(160 / timesteps_per_call)
    assert len(goals) == int(160 / timesteps_per_call)
    assert np.var(mean_rewards) > 1e-2
    assert np.var(goals) > 1e-2
    with pytest.raises(ValueError):
        sampler.obtain_samples(0,
                               10,
                               None,
                               env_update=tasks.sample(n_workers + 1))
Example #17
def test_vectorization_multi_raises():
    policy = FixedPolicy(None, np.array([1, 2, 3]))
    with pytest.raises(ValueError):
        policy.reset([True, True])
    with pytest.raises(ValueError):
        policy.get_actions(np.array([0, 0]))