def test_lambda():
    """observation_lambda_v0 on a gym-style env: value transforms stack, and a
    shape-changing transform updates both the observations and the space.
    """
    def add1(obs):
        return obs + 1

    base_env = DummyEnv(base_obs, base_obs_space, base_act_spaces)
    env = observation_lambda_v0(base_env, add1)
    obs0 = env.reset()
    assert int(obs0[0][0][0]) == 1
    # Wrapping again composes: two +1 transforms.
    env = observation_lambda_v0(env, add1)
    obs0 = env.reset()
    assert int(obs0[0][0][0]) == 2

    def tile_obs(obs):
        # Double the leading dimension of the observation.
        shape_size = len(obs.shape)
        tile_shape = [1] * shape_size
        tile_shape[0] *= 2
        return np.tile(obs, tile_shape)

    env = observation_lambda_v0(env, tile_obs)
    obs0 = env.reset()
    assert env.observation_space.shape == (16, 8, 3)

    def change_shape_fn(obs_space):
        return Box(low=0, high=1, shape=(32, 8, 3))

    # Bug fix: change_shape_fn was defined but never used — pass it so the
    # explicit space-transform argument of the wrapper is exercised too.
    env = observation_lambda_v0(env, tile_obs, change_shape_fn)
    obs0 = env.reset()
    assert env.observation_space.shape == (32, 8, 3)
    assert obs0.shape == (32, 8, 3)
def hanabi_maker():
    """Build a Hanabi AEC env whose observations are reduced to the raw
    "observation" entry of the env's dict observations.
    """
    def strip_obs(obs, obs_space):
        return obs["observation"]

    def strip_space(obs_space):
        return obs_space["observation"]

    return supersuit.observation_lambda_v0(hanabi_v4.env(), strip_obs, strip_space)
def mahjong_maker():
    """Build a Mahjong AEC env whose observations are reduced to the raw
    "observation" entry of the env's dict observations.
    """
    def strip_obs(obs, obs_space):
        return obs["observation"]

    def strip_space(obs_space):
        return obs_space["observation"]

    return supersuit.observation_lambda_v0(mahjong_v4.env(), strip_obs, strip_space)
def env_fn():
    """Create a SpaceInvaders env: Atari preprocessing (reward clipping off),
    a 4-frame stack, and a HWC -> CHW transpose of the observations.
    """
    def to_chw(obs):
        return np.transpose(obs, axes=(2, 0, 1))

    atari = AtariWrapper(gym.make("SpaceInvadersNoFrameskip-v4"), clip_reward=False)
    stacked = supersuit.frame_stack_v1(atari, 4)
    # NOTE: dtype_v0 / normalize_obs_v0 steps were deliberately left disabled
    # in the original pipeline.
    return supersuit.observation_lambda_v0(stacked, to_chw)
def test_observation_lambda():
    """observation_lambda_v0 on an AEC env: stacked value transforms,
    shape/space updates, and per-agent (obs, agent) / (obs_space, agent)
    lambdas.
    """
    def add1(obs):
        return obs + 1

    base_env = DummyEnv(base_obs, base_obs_space, base_act_spaces)
    env = observation_lambda_v0(base_env, add1)
    env.reset()
    obs0, _, _, _ = env.last()
    assert int(obs0[0][0][0]) == 1
    # Wrapping again composes: two +1 transforms.
    env = observation_lambda_v0(env, add1)
    env.reset()
    obs0, _, _, _ = env.last()
    assert int(obs0[0][0][0]) == 2

    def tile_obs(obs):
        # Double the leading dimension of the observation.
        shape_size = len(obs.shape)
        tile_shape = [1] * shape_size
        tile_shape[0] *= 2
        return np.tile(obs, tile_shape)

    env = observation_lambda_v0(env, tile_obs)
    env.reset()
    obs0, _, _, _ = env.last()
    assert env.observation_spaces[env.agent_selection].shape == (16, 8, 3)

    def change_shape_fn(obs_space):
        return Box(low=0, high=1, shape=(32, 8, 3))

    # Bug fix: change_shape_fn was defined but never used — pass it so the
    # explicit space-transform argument of the wrapper is exercised too.
    env = observation_lambda_v0(env, tile_obs, change_shape_fn)
    env.reset()
    obs0, _, _, _ = env.last()
    assert env.observation_spaces[env.agent_selection].shape == (32, 8, 3)
    assert obs0.shape == (32, 8, 3)

    # Per-agent observation transform only (space transform inferred).
    base_env = DummyEnv(base_obs, base_obs_space, base_act_spaces)
    env = observation_lambda_v0(
        base_env,
        lambda obs, agent: obs + base_env.possible_agents.index(agent))
    env.reset()
    obs0 = env.observe(env.agents[0])
    obs1 = env.observe(env.agents[1])
    assert int(obs0[0][0][0]) == 0
    assert int(obs1[0][0][0]) == 2
    assert (env.observation_spaces[env.agents[0]].high + 1 == env.observation_spaces[env.agents[1]].high).all()

    # Per-agent observation AND space transforms supplied explicitly.
    base_env = DummyEnv(base_obs, base_obs_space, base_act_spaces)
    env = observation_lambda_v0(
        base_env,
        lambda obs, agent: obs + base_env.possible_agents.index(agent),
        lambda obs_space, agent: Box(obs_space.low, obs_space.high + base_env.possible_agents.index(agent)))
    env.reset()
    obs0 = env.observe(env.agents[0])
    obs1 = env.observe(env.agents[1])
    assert int(obs0[0][0][0]) == 0
    assert int(obs1[0][0][0]) == 2
    assert (env.observation_spaces[env.agents[0]].high + 1 == env.observation_spaces[env.agents[1]].high).all()
def pad_observations_v0(env):
    """Wrap an AEC or Parallel env so that every agent's observations are
    padded up to one shared, homogenized observation space.
    """
    assert isinstance(env, (AECEnv, ParallelEnv)), "pad_observations_v0 only accepts an AECEnv or ParallelEnv"
    assert hasattr(env, "possible_agents"), "environment passed to pad_observations must have a possible_agents list."

    agent_spaces = [env.observation_space(agent) for agent in env.possible_agents]
    homogenize_ops.check_homogenize_spaces(agent_spaces)
    target_space = homogenize_ops.homogenize_spaces(agent_spaces)

    def pad_obs(obs, obs_space):
        return homogenize_ops.homogenize_observations(target_space, obs)

    def padded_space(obs_space):
        return target_space

    return observation_lambda_v0(env, pad_obs, padded_space)
def one_hot_obs_wrapper(env: AECEnv) -> AECEnv:
    """
    :param env: env with observation space of Discrete(n)
    :return: wrapper env with observation as one-hot encoding
    """
    def one_hot(x, n):
        # Length-n vector with a 1.0 at position x.
        v = np.zeros(n)
        v[x] = 1.0
        return v

    # Encode every agent with the largest Discrete size so that the one-hot
    # vectors have a uniform length across agents.
    max_obs_n = max(obs_space.n for obs_space in env.observation_spaces.values())
    env = observation_lambda_v0(
        env,
        lambda obs: one_hot(obs, max_obs_n),
        # Bug fix: the reported space must match the emitted observation,
        # which always has length max_obs_n — the original used the per-agent
        # obs_space.n, which disagrees when agents have different sizes.
        lambda obs_space: gym.spaces.Box(low=np.full(max_obs_n, -np.inf),
                                         high=np.full(max_obs_n, np.inf)),
    )
    return env
def agent_indicator_v0(env, type_only=False):
    """Wrap an AEC or Parallel env so that each observation carries an
    indicator identifying which agent (or agent type, if type_only) it
    belongs to.
    """
    assert isinstance(env, (AECEnv, ParallelEnv)), "agent_indicator_v0 only accepts an AECEnv or ParallelEnv"
    assert hasattr(env, "possible_agents"), "environment passed to agent indicator wrapper must have the possible_agents attribute."

    indicator_map = agent_ider.get_indicator_map(env.possible_agents, type_only)
    num_indicators = len(set(indicator_map.values()))
    agent_ider.check_params([env.observation_space(agent) for agent in env.possible_agents])

    def add_indicator(obs, obs_space, agent):
        return agent_ider.change_observation(
            obs, obs_space, (indicator_map[agent], num_indicators))

    def indicator_space(obs_space):
        return agent_ider.change_obs_space(obs_space, num_indicators)

    return observation_lambda_v0(env, add_indicator, indicator_space)