class ObserverMinigrid(Wrapper):
    """Wrapper that routes ``reset``/``step`` through a ``FlatObsWrapper``.

    The environment passed in is first handed to the base ``Wrapper``
    initialiser; afterwards ``self.env`` is rebound to ``FlatObsWrapper(env)``.
    """

    def __init__(self, env):
        """Initialise the base wrapper, then wrap ``env`` in ``FlatObsWrapper``."""
        super().__init__(env)
        self.env = FlatObsWrapper(env)

    def reset(self):
        """Reset the wrapped environment.

        Returns:
            A 2-tuple: a one-element list holding ``obs["image"]`` and a
            NumPy array built from ``[obs["direction"]]``.
        """
        # NOTE(review): the reset observation is indexed by string key here
        # while step() forwards its observation untouched -- confirm that
        # self.env.reset() really yields a mapping with these keys.
        initial = self.env.reset()
        image_batch = [initial["image"]]
        direction = np.asarray([initial["direction"]])
        return (image_batch, direction)

    def step(self, action):
        """Forward ``action`` to the wrapped environment.

        Returns:
            The ``(obs, reward, done, info)`` 4-tuple produced by
            ``self.env.step(action)``, unchanged.
        """
        observation, reward, finished, extras = self.env.step(action)
        return observation, reward, finished, extras
    def _thunk():
        """Build, seed and wrap one environment instance for ``env_id``.

        ``env_id``, ``seed``, ``rank``, ``log_dir`` and ``allow_early_resets``
        are free variables resolved outside this function (presumably the
        arguments of the enclosing factory -- not visible here).

        Returns:
            The environment after seeding and all applicable wrappers.

        Raises:
            NotImplementedError: If the environment is neither Atari nor
                MiniGrid and its observation space has three dimensions.
        """
        if env_id.startswith("dm"):
            _, domain, task = env_id.split('.')
            env = dm_control2gym.make(domain_name=domain, task_name=task)
        elif 'Mini' in env_id:
            import gym_minigrid
            # NOTE(review): the requested id is ignored on this branch; a
            # fixed 5x5 dynamic-obstacles env is always built -- confirm
            # this is intentional.
            env = gym_minigrid.envs.dynamicobstacles.DynamicObstaclesEnv(
                size=5, n_obstacles=1)
            env = gym_minigrid.wrappers.NoOpAsync(env,
                                                  costs=[2, 1],
                                                  which_obs='first')
        else:
            # Without this fallback `env` stays unbound for every other id
            # (e.g. Atari), and the code below fails with UnboundLocalError.
            env = gym.make(env_id)

        is_atari = hasattr(gym.envs, 'atari') and isinstance(
            env.unwrapped, gym.envs.atari.atari_env.AtariEnv)
        if is_atari:
            env = make_atari(env_id)

        # NOTE(review): construction above tests for 'Mini' (case-sensitive)
        # while this flag tests for 'minigrid' (case-insensitive); ids that
        # match only the former are built as MiniGrid but not flattened.
        is_minigrid = 'minigrid' in env_id.lower()

        env.seed(seed + rank)

        if is_minigrid:
            from gym_minigrid.wrappers import FlatObsWrapper
            env = FlatObsWrapper(env)

        # Only the outermost wrapper's class name is inspected.
        if 'TimeLimit' in env.__class__.__name__:
            env = TimeLimitMask(env)

        if log_dir is not None:
            env = bench.Monitor(env,
                                os.path.join(log_dir, str(rank)),
                                allow_early_resets=allow_early_resets)

        if is_atari:
            if len(env.observation_space.shape) == 3:
                env = wrap_deepmind(env)
        elif not is_minigrid and len(env.observation_space.shape) == 3:
            raise NotImplementedError(
                "CNN models work only for atari,\n"
                "please use a custom wrapper for a custom pixel input env.\n"
                "See wrap_deepmind for an example.")

        # If the input has shape (W,H,3), wrap for PyTorch convolutions.
        # Read the shape only now: the wrappers above may have changed it.
        obs_shape = env.observation_space.shape
        if len(obs_shape) == 3 and obs_shape[2] in (1, 3):
            env = TransposeImage(env, op=[2, 0, 1])

        return env
Exemple #3
0
            if reward_mean > 500:
                break

    def play(self, num_episodes, render=True):
        """Run the agent for ``num_episodes`` episodes and report each return.

        Every episode starts with ``self.env.reset()`` and continues until the
        environment signals ``done``. Actions come from ``self.get_action``.
        When ``render`` is truthy, ``self.env.render()`` is called before each
        step. The accumulated reward is printed once per episode.
        """
        for episode_idx in range(num_episodes):
            observation = self.env.reset()
            episode_return = 0.0
            finished = False
            while not finished:
                if render:
                    self.env.render()
                chosen = self.get_action(observation)
                observation, step_reward, finished, _ = self.env.step(chosen)
                episode_return += step_reward
            print(
                f"Total reward: {episode_return} in episode {episode_idx + 1}"
            )


if __name__ == "__main__":
    # Script entry point: build a flattened-observation 8x8 empty MiniGrid,
    # report the agent's observation/action counts, train, then replay.
    env = FlatObsWrapper(gym.make("MiniGrid-Empty-8x8-v0"))
    agent = Agent(env)

    print("Number of observations: ", agent.observations)
    print("Number of actions: ", agent.actions)

    agent.train(percentile=99.9, num_iterations=64, num_episodes=128)
    agent.play(num_episodes=3)
def make_env(name="MiniGrid-Empty-5x5-v0"):
    """Create the Gym environment ``name`` wrapped in ``FlatObsWrapper``.

    Args:
        name: Environment id passed to ``gym.make``.

    Returns:
        ``FlatObsWrapper(gym.make(name))``.
    """
    return FlatObsWrapper(gym.make(name))
 def __init__(self, env):
     """Initialise the base class with ``env``, then rebind ``self.env``.

     After construction ``self.env`` refers to ``FlatObsWrapper(env)``.
     """
     super(ObserverMinigrid, self).__init__(env)
     # Assigned after the base initialiser so this binding is the one kept.
     self.env = FlatObsWrapper(env)
 def mini_grid_wrap(env, _):
     """Wrap ``env`` in ``FlatObsWrapper`` when its class is a ``MiniGridEnv``.

     The check is made on ``env.__class__``; any other environment is
     returned unchanged. The second positional argument is ignored.
     """
     if not issubclass(env.__class__, MiniGridEnv):
         return env
     return FlatObsWrapper(env)