Ejemplo n.º 1
0
    def make_env(i, this_seed):
        """Construct the ``i``-th sub-environment: instantiate, seed and wrap it.

        Args:
            i: Index of this environment; only used to name its monitor file.
            this_seed: Seed handed to ``env.seed`` after conversion to ``int``.

        Returns:
            The environment wrapped (innermost to outermost) in an optional
            ``TimeLimit``, a ``monitor.Monitor`` and a
            ``wrappers.RolloutInfoWrapper``.
        """
        # Build from the captured ``spec`` instead of ``gym.make(env_name)``:
        # when ``imitation.scripts.train_adversarial`` ran inside
        # ``imitation.scripts.parallel`` (Ray), ``gym.make`` failed here for
        # custom environments unless they were re-registered inside this very
        # function (registering in the enclosing scope was not enough).
        # Background: https://github.com/HumanCompatibleAI/imitation/pull/160.
        env = spec.make()

        # The seed is applied verbatim; any per-environment diversification is
        # presumably done by the caller -- confirm. int() works around a gym
        # bug where numpy int64 seeds are rejected.
        env.seed(int(this_seed))

        # An explicit ``max_episode_steps`` takes precedence over the spec's.
        if max_episode_steps is None:
            horizon = spec.max_episode_steps
        else:
            horizon = max_episode_steps
        if horizon is not None:
            env = TimeLimit(env, max_episode_steps=horizon)

        # Monitor records the statistics Baselines-style logging relies on; a
        # file path is only supplied (so results are saved) when ``log_dir``
        # is set.
        monitor_file = None
        if log_dir is not None:
            monitor_dir = os.path.join(log_dir, "monitor")
            os.makedirs(monitor_dir, exist_ok=True)
            monitor_file = os.path.join(monitor_dir, f"mon{i:03d}")

        return wrappers.RolloutInfoWrapper(monitor.Monitor(env, monitor_file))
Ejemplo n.º 2
0
def test_unwrap_traj():
    """Check that unwrap_traj reverses `ObsRewHalveWrapper`.

    Trajectories are collected from CartPole through ``ObsRewHalveWrapper``
    stacked on ``wrappers.RolloutInfoWrapper``. ``rollout.unwrap_traj`` must
    recover the observations and rewards from before halving, with actions
    unchanged. Also check that unwrapping twice is a no-op.
    """
    env = gym.make("CartPole-v1")
    env = wrappers.RolloutInfoWrapper(env)
    env = ObsRewHalveWrapper(env)
    venv = vec_env.DummyVecEnv([lambda: env])

    try:
        with serialize.load_policy("zero", "UNUSED", venv) as policy:
            trajs = rollout.generate_trajectories(
                policy, venv, rollout.min_episodes(10)
            )
    finally:
        # Release the environment even if rollout generation fails; the
        # assertions below only need the collected trajectory data.
        venv.close()

    # Without trajectories the zip() loops below would pass vacuously.
    assert trajs, "expected at least one trajectory to be generated"

    trajs_unwrapped = [rollout.unwrap_traj(t) for t in trajs]
    trajs_unwrapped_twice = [rollout.unwrap_traj(t) for t in trajs_unwrapped]

    # The wrapper halves obs and rewards, so wrapped == unwrapped / 2.
    for t, t_unwrapped in zip(trajs, trajs_unwrapped):
        np.testing.assert_allclose(t.acts, t_unwrapped.acts)
        np.testing.assert_allclose(t.obs, t_unwrapped.obs / 2)
        np.testing.assert_allclose(t.rews, t_unwrapped.rews / 2)

    # A second unwrap must leave an already-unwrapped trajectory untouched.
    for t1, t2 in zip(trajs_unwrapped, trajs_unwrapped_twice):
        np.testing.assert_equal(t1.acts, t2.acts)
        np.testing.assert_equal(t1.obs, t2.obs)
        np.testing.assert_equal(t1.rews, t2.rews)
Ejemplo n.º 3
0
def create_env(env):
    """Create a Gym environment and wrap it in ``RolloutInfoWrapper``.

    Args:
        env: Forwarded unchanged to ``gym.make`` (e.g. an environment ID).

    Returns:
        The freshly made environment wrapped in ``wrappers.RolloutInfoWrapper``.
    """
    base_env = gym.make(env)
    return wrappers.RolloutInfoWrapper(base_env)