def make_env(env: gym.Env,
             length: int = 1,
             action_limits: Tuple[float, float] = (-1.0, 1.0),
             image_size: Tuple[int, int] = (64, 64),
             nchw_format: bool = True) -> gym.Env:
    """Wrap ``env`` with action rescaling, image conversion and frame stacking.

    The wrappers are applied in this order (innermost first):

    1. ``wrappers.RescaleAction(env, *action_limits)``
    2. ``ConvertImage(env, image_size, nchw_format, dict_key='image')``
    3. ``FrameStack(env, length, nchw_format)``

    Args:
        env: The environment to wrap.
        length: Forwarded to ``FrameStack`` (presumably the number of
            stacked frames -- confirm against ``FrameStack``).
        action_limits: ``(low, high)`` pair unpacked into ``RescaleAction``.
        image_size: Forwarded to ``ConvertImage``.
        nchw_format: Forwarded to both ``ConvertImage`` and ``FrameStack``.

    Returns:
        The wrapped environment.
    """
    env = wrappers.RescaleAction(env, *action_limits)
    env = ConvertImage(env, image_size, nchw_format, dict_key='image')
    env = FrameStack(env, length, nchw_format)
    return env
def __init__(self,
             domain,
             task,
             *args,
             env=None,
             rescale_action_range=(-1.0, 1.0),
             rescale_observation_range=None,
             observation_keys=(),
             goal_keys=(),
             unwrap_time_limit=True,
             pixel_wrapper_kwargs=None,
             **kwargs):
    """Build (or adopt) a Gym environment and expose dict observations.

    Either ``domain``/``task`` are given and the environment is created
    through ``gym.envs.make``, or a ready-made ``env`` is passed in -- never
    both.

    Args:
        domain: First part of the Gym environment id. Must be ``None`` when
            ``env`` is provided.
        task: Second part of the Gym environment id. The id
            ``f"{domain}-{task}"`` is tried first; if Gym reports it as
            unregistered, ``f"{domain}{task}"`` is tried. Must be ``None``
            when ``env`` is provided.
        *args: Not supported; any positional extras fail an assertion.
        env: An already constructed environment. When given, ``kwargs`` must
            be empty.
        rescale_action_range: If truthy and the action space is continuous
            (per ``is_continuous_space``), unpacked into
            ``wrappers.RescaleAction``.
        rescale_observation_range: If truthy, unpacked into
            ``RescaleObservation``.
        observation_keys: Keys to keep when the environment has a ``Dict``
            observation space. Empty means "all keys of that space". Ignored
            for ``Box`` spaces, which are exposed under
            ``DEFAULT_OBSERVATION_KEY``.
        goal_keys: Forwarded to the parent initializer. ``self.goal_keys`` is
            read afterwards, so the parent is assumed to set it -- confirm
            against the base class.
        unwrap_time_limit: If true and the outermost wrapper is a
            ``wrappers.TimeLimit``, that wrapper is removed.
        pixel_wrapper_kwargs: If not ``None``, the environment is wrapped in
            ``wrappers.PixelObservationWrapper(env, **pixel_wrapper_kwargs)``.
        **kwargs: Forwarded to the parent initializer and to
            ``gym.envs.make``.

    Raises:
        AssertionError: On positional ``args``, or on an inconsistent
            combination of ``env``, ``domain``, ``task`` and ``kwargs``.
        NotImplementedError: If the observation space is neither ``Dict`` nor
            ``Box``, or if the action space shape has more than one
            dimension.
    """
    assert not args, (
        "Gym environments don't support args. Use kwargs instead.")

    self.rescale_action_range = rescale_action_range
    self.rescale_observation_range = rescale_observation_range
    self.unwrap_time_limit = unwrap_time_limit

    super(GymAdapter, self).__init__(
        domain, task, *args, goal_keys=goal_keys, **kwargs)

    if env is None:
        assert (domain is not None and task is not None), (domain, task)
        try:
            env_id = f"{domain}-{task}"
            env = gym.envs.make(env_id, **kwargs)
        except gym.error.UnregisteredEnv:
            # Fall back to ids registered without the dash separator.
            env_id = f"{domain}{task}"
            env = gym.envs.make(env_id, **kwargs)
        self._env_kwargs = kwargs
    else:
        assert not kwargs
        assert domain is None and task is None, (domain, task)

    if isinstance(env, wrappers.TimeLimit) and unwrap_time_limit:
        # Remove the TimeLimit wrapper that sets 'done = True' when
        # the time limit specified for each environment has been passed and
        # therefore the environment is not Markovian (terminal condition
        # depends on time rather than state).
        env = env.env

    if rescale_observation_range:
        env = RescaleObservation(env, *rescale_observation_range)

    if rescale_action_range and is_continuous_space(env.action_space):
        env = wrappers.RescaleAction(env, *rescale_action_range)

    # TODO(hartikainen): We need the clip action wrapper because sometimes
    # the tfp.bijectors.Tanh() produces values strictly greater than 1 or
    # strictly less than -1, which causes the env fail without clipping.
    # The error is in the order of 1e-7, which should not cause issues.
    # See https://github.com/tensorflow/probability/issues/664.
    env = wrappers.ClipAction(env)

    if pixel_wrapper_kwargs is not None:
        env = wrappers.PixelObservationWrapper(env, **pixel_wrapper_kwargs)

    self._env = env

    env_observation_space = self._env.observation_space
    if isinstance(env_observation_space, spaces.Dict):
        dict_observation_space = env_observation_space
        # Default to every key of the env's own Dict space. This reads the
        # env space directly because `self._observation_space` is only
        # assigned further down in this method.
        self.observation_keys = (
            observation_keys or (*dict_observation_space.spaces.keys(), ))
    elif isinstance(env_observation_space, spaces.Box):
        dict_observation_space = spaces.Dict(
            OrderedDict(((DEFAULT_OBSERVATION_KEY, env_observation_space), )))
        self.observation_keys = (DEFAULT_OBSERVATION_KEY, )
    else:
        # Fail with a clear message instead of an unbound-variable error.
        raise NotImplementedError(
            "Unsupported observation space type: {}. Expected spaces.Dict"
            " or spaces.Box.".format(type(env_observation_space)))

    # Computed once; tolerant of list/tuple mixes between the two key sets.
    kept_keys = tuple(self.observation_keys) + tuple(self.goal_keys)

    # Deep-copy the kept sub-spaces so the adapter's observation space does
    # not share space objects with the wrapped environment.
    self._observation_space = type(dict_observation_space)([
        (name, copy.deepcopy(space))
        for name, space in dict_observation_space.spaces.items()
        if name in kept_keys
    ])

    if len(self._env.action_space.shape) > 1:
        raise NotImplementedError(
            "Shape of the action space ({}) is not flat, make sure to"
            " check the implemenation.".format(self._env.action_space))

    self._action_space = self._env.action_space