コード例 #1
0
ファイル: env_checker.py プロジェクト: jeehyun100/control_rl
def _check_returned_values(env: gym.Env, observation_space: spaces.Space, action_space: spaces.Space) -> None:
    """
    Validate what the env hands back from `.reset()` and from a single `.step()`.

    The observation from each call is passed to `_check_obs` together with
    `observation_space`. The step result must contain exactly four items, whose
    reward / done / info entries are type-checked. For a `gym.GoalEnv`, the
    reward is additionally compared against `env.compute_reward`.

    :param env: environment under test
    :param observation_space: space the observations are checked against
    :param action_space: space a random action is sampled from
    """
    # env derives from gym.Env, so `reset()` and `step()` are assumed to exist
    initial_obs = env.reset()
    _check_obs(initial_obs, observation_space, 'reset')

    # Drive one transition with a randomly sampled action
    step_result = env.step(action_space.sample())
    assert len(step_result) == 4, "The `step()` method must return four values: obs, reward, done, info"

    next_obs, reward, done, info = step_result
    _check_obs(next_obs, observation_space, 'step')

    # (value, accepted types, failure message); int is accepted for the reward
    # because the reward will be cast to float
    type_expectations = (
        (reward, (float, int), "The reward returned by `step()` must be a float"),
        (done, bool, "The `done` signal must be a boolean"),
        (info, dict, "The `info` returned by `step()` must be a python dictionary"),
    )
    for value, accepted, failure_message in type_expectations:
        assert isinstance(value, accepted), failure_message

    if not isinstance(env, gym.GoalEnv):
        return

    # For a GoalEnv, the keys are checked at reset
    expected_reward = env.compute_reward(next_obs['achieved_goal'], next_obs['desired_goal'], info)
    assert reward == expected_reward
コード例 #2
0
def _check_returned_values(env: gym.Env, observation_space: spaces.Space,
                           action_space: spaces.Space) -> None:
    """
    Check the values returned by the env when calling `.reset()` and `.step()`.

    The observation from each call is validated against `observation_space`
    (key by key when the space is a `spaces.Dict`). The result of `step()` must
    hold exactly four items, and its reward / done / info entries are
    type-checked. For a `gym.GoalEnv`, the reward is also compared against
    `env.compute_reward`.

    :param env: environment under test
    :param observation_space: space the observations are checked against
    :param action_space: space a random action is sampled from
    :raises AssertionError: if any of the checks fails
    """

    def _validate_obs(obs, method_name: str) -> None:
        # Single implementation of the observation check, shared by the
        # `reset` and `step` paths so the two cannot drift apart.
        if not isinstance(observation_space, spaces.Dict):
            _check_obs(obs, observation_space, method_name)
            return

        assert isinstance(
            obs, dict
        ), f"The observation returned by `{method_name}()` must be a dictionary"
        for key, subspace in observation_space.spaces.items():
            try:
                _check_obs(obs[key], subspace, method_name)
            except AssertionError as e:
                # Prefix the failing key and chain explicitly so the original
                # failure stays attached as the direct cause.
                raise AssertionError(
                    f"Error while checking key={key}: " + str(e)) from e

    # because env inherits from gym.Env, we assume that `reset()` and `step()` methods exists
    _validate_obs(env.reset(), "reset")

    # Sample a random action
    action = action_space.sample()
    data = env.step(action)

    assert (
        len(data) == 4
    ), "The `step()` method must return four values: obs, reward, done, info"

    # Unpack
    obs, reward, done, info = data

    _validate_obs(obs, "step")

    # We also allow int because the reward will be cast to float
    assert isinstance(
        reward,
        (float, int,
         np.float32)), "The reward returned by `step()` must be a float"
    assert isinstance(done, bool), "The `done` signal must be a boolean"
    assert isinstance(
        info,
        dict), "The `info` returned by `step()` must be a python dictionary"

    if isinstance(env, gym.GoalEnv):
        # For a GoalEnv, the keys are checked at reset
        assert reward == env.compute_reward(obs["achieved_goal"],
                                            obs["desired_goal"], info)
コード例 #3
0
def _check_returned_values(env: gym.Env, observation_space: Space, action_space: Space):
    """Check the returned values by the env when calling :meth:`env.reset` or :meth:`env.step` methods.

    The observation from each call is validated against ``observation_space``
    (key by key when the space is a ``Dict``). The result of ``step()`` must
    hold exactly four items, and its reward / done / info entries are
    type-checked.

    Args:
        env: The environment
        observation_space: The environment's observation space
        action_space: The environment's action space

    Raises:
        AssertionError: If any of the checks fails.
    """

    def _validate_obs(obs, method_name: str) -> None:
        # Single implementation of the observation check, shared by the
        # `reset` and `step` paths so the two cannot drift apart.
        if not isinstance(observation_space, Dict):
            _check_obs(obs, observation_space, method_name)
            return

        assert isinstance(
            obs, dict
        ), f"The observation returned by `{method_name}()` must be a dictionary"
        for key, subspace in observation_space.spaces.items():
            try:
                _check_obs(obs[key], subspace, method_name)
            except AssertionError as e:
                # Prefix the failing key and chain explicitly so the original
                # failure stays attached as the direct cause.
                raise AssertionError(
                    f"Error while checking key={key}: " + str(e)
                ) from e

    # because env inherits from gym.Env, we assume that `reset()` and `step()` methods exists
    _validate_obs(env.reset(), "reset")

    # Sample a random action
    action = action_space.sample()
    data = env.step(action)

    assert (
        len(data) == 4
    ), "The `step()` method must return four values: obs, reward, done, info"

    # Unpack
    obs, reward, done, info = data

    _validate_obs(obs, "step")

    # We also allow int because the reward will be cast to float
    assert isinstance(
        reward, (float, int, np.float32)
    ), "The reward returned by `step()` must be a float"
    assert isinstance(done, bool), "The `done` signal must be a boolean"
    assert isinstance(
        info, dict
    ), "The `info` returned by `step()` must be a python dictionary"
コード例 #4
0
ファイル: helpers.py プロジェクト: mtcrawshaw/meta
def get_obs_batch(batch_size: int, obs_space: Space,
                  num_tasks: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Build a batch of multi-task observations together with their task indices.

    Each row is a sample from `obs_space` concatenated with the vector returned
    by `one_hot_tensor(num_tasks)`. `obs_space` must be one-dimensional.

    Returns a pair `(obs, task_indices)`: the stacked rows, and for each row the
    column (within the task part) of its nonzero entry.
    """

    # One throwaway sample is drawn only to learn the observation length.
    sample_shape = obs_space.sample().shape
    assert len(sample_shape) == 1
    (obs_len,) = sample_shape

    rows = [
        torch.cat([torch.Tensor(obs_space.sample()), one_hot_tensor(num_tasks)])
        for _ in range(batch_size)
    ]
    obs = torch.stack(rows)

    # Locate the nonzero entries of the task part; there must be exactly one
    # per row, listed in row order.
    hot_positions = obs[:, obs_len:].nonzero()
    assert hot_positions[:, 0].tolist() == list(range(batch_size))

    return obs, hot_positions[:, 1]
コード例 #5
0
def check_run(env: gym.Env, action_space: spaces.Space):
    """Check the step counter of the webot env over several normal runs.

    Resets the env 3 times and, after each reset, takes up to 100 random
    actions. When `step()` reports `done`, `env.steps_in_run` must equal the
    number of steps taken so far in that run; if the run never reports `done`,
    `env.steps_in_run` must equal 100 after the last step.

    :param env: environment under test; must expose a `steps_in_run` attribute
    :param action_space: space random actions are sampled from
    :raises AssertionError: if `env.steps_in_run` disagrees with the step count
    """
    num_env = 3
    time_steps = 100
    for _ in range(num_env):
        env.reset()
        for j in range(time_steps):
            action = action_space.sample()
            _, _, done, _ = env.step(action)
            # Truthiness instead of `is True`, so a truthy non-`bool` flag
            # (e.g. a numpy boolean) also ends the run.
            if done:
                assert j + 1 == env.steps_in_run, \
                    "`env.steps_in_run` must equal the number of steps taken when 'done' is returned"
                break
        else:
            # Reached only when the run used every step without `done`.
            assert env.steps_in_run == time_steps, \
                "`env.steps_in_run` must equal the number of steps taken when the run never finished"
コード例 #6
0
def check_reset_step(env: gym.Env, observation_space: spaces.Space,
                     action_space: spaces.Space):
    """Check that `reset()` and `step()` produce valid, changing observations.

    Two consecutive resets must yield observations that differ in the index
    ranges 1:8 and 10: (the latter is the lidar data, per the assertion
    message). Then three random actions are taken, each resulting observation
    is validated with `_check_obs`, and the last one must differ from the
    second reset observation in the ranges 0:3, 6:9 and 10:.
    """
    first_reset_obs = env.reset()
    _check_obs(first_reset_obs, observation_space, 'reset')
    second_reset_obs = env.reset()

    # (index window, message shown when the window did not change)
    reset_expectations = (
        (slice(1, 8),
         "The infos of the observation must differ after reset the webot env."),
        (slice(10, None),
         "The infos of the lidar data must differ after reset the webot env."),
    )
    for window, failure_message in reset_expectations:
        assert (first_reset_obs[window] != second_reset_obs[window]).any(), \
            failure_message

    for _ in range(3):
        sampled_action = action_space.sample()
        latest_obs, _reward, _done, _info = env.step(sampled_action)
        _check_obs(latest_obs, observation_space, 'step')

    # Compared once, after all three steps, against the second reset observation.
    step_expectations = (
        (slice(0, 3),
         "The information of observation must be updated after the first action"),
        (slice(6, 9),
         "The information of observation must be updated after the first action"),
        (slice(10, None),
         "The information of lidar data must be updated after the first action"),
    )
    for window, failure_message in step_expectations:
        assert (latest_obs[window] != second_reset_obs[window]).any(), \
            failure_message