コード例 #1
0
def test_lambda():
    """Exercise stacked ``observation_lambda_v0`` wrappers on a ``DummyEnv``.

    Two things are checked:

    * wrapping twice with ``add1`` composes the transforms, so the first
      element of the reset observation goes 1 -> 2;
    * wrapping with ``tile_obs`` doubles the first axis of both the returned
      observation and ``env.observation_space``. No space-changing function
      is passed, so the wrapper itself has to report the new shape.

    NOTE(review): the expected shapes imply the base observation is
    (8, 8, 3) and starts at 0 in position [0][0][0] -- confirm against the
    module-level ``base_obs`` fixture.
    """
    def add1(obs):
        return obs + 1

    base_env = DummyEnv(base_obs, base_obs_space, base_act_spaces)
    env = observation_lambda_v0(base_env, add1)
    obs0 = env.reset()
    assert int(obs0[0][0][0]) == 1
    # A second wrapper is applied on top of the first one.
    env = observation_lambda_v0(env, add1)
    obs0 = env.reset()
    assert int(obs0[0][0][0]) == 2

    def tile_obs(obs):
        # Repeat the observation twice along its first axis only.
        tile_shape = [1] * len(obs.shape)
        tile_shape[0] *= 2
        return np.tile(obs, tile_shape)

    env = observation_lambda_v0(env, tile_obs)
    env.reset()
    assert env.observation_space.shape == (16, 8, 3)

    # Tiling again doubles the first axis a second time (16 -> 32).
    env = observation_lambda_v0(env, tile_obs)
    obs0 = env.reset()
    assert env.observation_space.shape == (32, 8, 3)
    assert obs0.shape == (32, 8, 3)
コード例 #2
0
def hanabi_maker():
    """Build a ``hanabi_v4`` env whose observations are reduced to their
    ``"observation"`` entry.

    Both the observation and its space are indexed with the same key, so the
    wrapped env reports the sub-space that matches what it emits.
    """
    def pick_observation(obs, obs_space):
        return obs["observation"]

    def pick_observation_space(obs_space):
        return obs_space["observation"]

    return supersuit.observation_lambda_v0(
        hanabi_v4.env(), pick_observation, pick_observation_space
    )
コード例 #3
0
def mahjong_maker():
    """Build a ``mahjong_v4`` env whose observations are reduced to their
    ``"observation"`` entry.

    Both the observation and its space are indexed with the same key, so the
    wrapped env reports the sub-space that matches what it emits.
    """
    def pick_observation(obs, obs_space):
        return obs["observation"]

    def pick_observation_space(obs_space):
        return obs_space["observation"]

    return supersuit.observation_lambda_v0(
        mahjong_v4.env(), pick_observation, pick_observation_space
    )
コード例 #4
0
def env_fn():
    """Create the SpaceInvaders env used by this example.

    Pipeline: ``AtariWrapper`` (reward clipping disabled) -> 4-frame stack ->
    observation axes reordered with ``np.transpose(..., axes=(2, 0, 1))``.
    """
    def move_last_axis_first(obs):
        # (a, b, c) -> (c, a, b)
        return np.transpose(obs, axes=(2, 0, 1))

    atari_env = AtariWrapper(
        gym.make("SpaceInvadersNoFrameskip-v4"), clip_reward=False
    )
    stacked_env = supersuit.frame_stack_v1(atari_env, 4)
    # Further preprocessing steps that are currently disabled:
    # env = supersuit.dtype_v0(env,np.float32)
    # env = supersuit.normalize_obs_v0(env)
    return supersuit.observation_lambda_v0(stacked_env, move_last_axis_first)
コード例 #5
0
ファイル: aec_mock_test.py プロジェクト: mimoralea/SuperSuit
def test_observation_lambda():
    """Exercise ``observation_lambda_v0`` on an AEC-style ``DummyEnv``.

    Scenarios covered, in order:

    1. Stacking two ``add1`` wrappers composes the transforms (1, then 2).
    2. Stacking two ``tile_obs`` wrappers doubles the first axis of both the
       observation and the selected agent's observation space each time.
       No space-changing function is passed, so the wrapper itself has to
       report the new shape.
    3. An observation function that also takes ``agent`` is given the agent,
       and the per-agent spaces differ accordingly.
    4. Same as 3, with an explicit per-agent space-changing function.

    NOTE(review): the expected values in scenarios 3 and 4 (0 for the first
    agent, 2 for the second) imply the module-level ``base_obs`` fixture
    already differs by 1 between the two agents -- confirm against the
    fixture definition.
    """
    def add1(obs):
        return obs + 1

    base_env = DummyEnv(base_obs, base_obs_space, base_act_spaces)
    env = observation_lambda_v0(base_env, add1)
    env.reset()
    obs0, _, _, _ = env.last()
    assert int(obs0[0][0][0]) == 1
    env = observation_lambda_v0(env, add1)
    env.reset()
    obs0, _, _, _ = env.last()
    assert int(obs0[0][0][0]) == 2

    def tile_obs(obs):
        # Repeat the observation twice along its first axis only.
        tile_shape = [1] * len(obs.shape)
        tile_shape[0] *= 2
        return np.tile(obs, tile_shape)

    env = observation_lambda_v0(env, tile_obs)
    env.reset()
    obs0, _, _, _ = env.last()
    assert env.observation_spaces[env.agent_selection].shape == (16, 8, 3)

    # Tiling again doubles the first axis a second time (16 -> 32).
    env = observation_lambda_v0(env, tile_obs)
    env.reset()
    obs0, _, _, _ = env.last()
    assert env.observation_spaces[env.agent_selection].shape == (32, 8, 3)
    assert obs0.shape == (32, 8, 3)

    # Agent-aware observation function, space left for the wrapper to derive.
    base_env = DummyEnv(base_obs, base_obs_space, base_act_spaces)
    env = observation_lambda_v0(base_env, lambda obs, agent: obs + base_env.possible_agents.index(agent))
    env.reset()
    obs0 = env.observe(env.agents[0])
    obs1 = env.observe(env.agents[1])

    assert int(obs0[0][0][0]) == 0
    assert int(obs1[0][0][0]) == 2
    assert (env.observation_spaces[env.agents[0]].high + 1 == env.observation_spaces[env.agents[1]].high).all()

    # Agent-aware observation function plus an explicit agent-aware space function.
    base_env = DummyEnv(base_obs, base_obs_space, base_act_spaces)
    env = observation_lambda_v0(base_env,
                                lambda obs, agent: obs + base_env.possible_agents.index(agent),
                                lambda obs_space, agent: Box(obs_space.low, obs_space.high + base_env.possible_agents.index(agent)))
    env.reset()
    obs0 = env.observe(env.agents[0])
    obs1 = env.observe(env.agents[1])

    assert int(obs0[0][0][0]) == 0
    assert int(obs1[0][0][0]) == 2
    assert (env.observation_spaces[env.agents[0]].high + 1 == env.observation_spaces[env.agents[1]].high).all()
コード例 #6
0
def pad_observations_v0(env):
    """Wrap ``env`` so every agent shares one homogenized observation space.

    The observation spaces of all ``env.possible_agents`` are collected,
    validated with ``homogenize_ops.check_homogenize_spaces`` and merged into
    a single space by ``homogenize_ops.homogenize_spaces``. Each observation
    is then passed through ``homogenize_ops.homogenize_observations`` against
    that merged space, and the merged space is reported for every agent.

    Raises ``AssertionError`` if ``env`` is neither an ``AECEnv`` nor a
    ``ParallelEnv``, or if it has no ``possible_agents`` attribute.
    """
    assert isinstance(
        env, (AECEnv, ParallelEnv)
    ), "pad_observations_v0 only accepts an AECEnv or ParallelEnv"
    assert hasattr(
        env, "possible_agents"
    ), "environment passed to pad_observations must have a possible_agents list."

    agent_spaces = []
    for agent in env.possible_agents:
        agent_spaces.append(env.observation_space(agent))
    homogenize_ops.check_homogenize_spaces(agent_spaces)
    padded_space = homogenize_ops.homogenize_spaces(agent_spaces)

    def pad_obs(obs, obs_space):
        # The agent's own space is ignored; the shared space is the target.
        return homogenize_ops.homogenize_observations(padded_space, obs)

    def shared_space(obs_space):
        return padded_space

    return observation_lambda_v0(env, pad_obs, shared_space)
コード例 #7
0
def one_hot_obs_wrapper(env: AECEnv) -> AECEnv:
    """
    Wrap ``env`` so each Discrete observation is delivered as a one-hot vector.

    Every agent's vector has length ``max_obs_n``, the largest ``n`` among
    the env's observation spaces. The advertised Box space uses that same
    length, so the declared shape always matches the emitted observation,
    including when agents have Discrete spaces of different sizes.

    :param env: env with observation space of Discrete(n)
    :return: wrapper env with observation as one-hot encoding
    """
    max_obs_n = max(
        obs_space.n for obs_space in env.observation_spaces.values())

    def one_hot(x, n):
        v = np.zeros(n)
        v[x] = 1.0
        return v

    env = observation_lambda_v0(
        env,
        lambda obs: one_hot(obs, max_obs_n),
        # Size the space by max_obs_n, not the per-agent obs_space.n: the
        # observation function above always emits max_obs_n entries.
        lambda obs_space: gym.spaces.Box(low=np.full(max_obs_n, -np.inf),
                                         high=np.full(max_obs_n, np.inf)),
    )
    return env
コード例 #8
0
def agent_indicator_v0(env, type_only=False):
    """Wrap ``env`` so each observation is modified with an agent indicator.

    An indicator map is built from ``env.possible_agents`` (``type_only`` is
    forwarded to ``agent_ider.get_indicator_map``), and the number of
    distinct indicator values is counted. Observations are rewritten by
    ``agent_ider.change_observation`` using the acting agent's indicator, and
    spaces by ``agent_ider.change_obs_space``.

    Raises ``AssertionError`` if ``env`` is neither an ``AECEnv`` nor a
    ``ParallelEnv``, or if it has no ``possible_agents`` attribute.
    """
    assert isinstance(
        env, (AECEnv, ParallelEnv)
    ), "agent_indicator_v0 only accepts an AECEnv or ParallelEnv"
    assert hasattr(
        env, "possible_agents"
    ), "environment passed to agent indicator wrapper must have the possible_agents attribute."

    indicator_map = agent_ider.get_indicator_map(env.possible_agents, type_only)
    num_indicators = len(set(indicator_map.values()))

    # Validate every agent's space before wrapping.
    agent_spaces = []
    for agent in env.possible_agents:
        agent_spaces.append(env.observation_space(agent))
    agent_ider.check_params(agent_spaces)

    def add_indicator(obs, obs_space, agent):
        indicator = (indicator_map[agent], num_indicators)
        return agent_ider.change_observation(obs, obs_space, indicator)

    def extend_space(obs_space):
        return agent_ider.change_obs_space(obs_space, num_indicators)

    return observation_lambda_v0(env, add_indicator, extend_space)