def __init__(self,
                 domain,
                 task,
                 *args,
                 normalize=True,
                 observation_keys=None,
                 unwrap_time_limit=True,
                 **kwargs):
        """Build a gym environment from the ``GYM_ENVIRONMENTS`` registry.

        Args:
            domain: First-level key into ``GYM_ENVIRONMENTS``.
            task: Second-level key; ``GYM_ENVIRONMENTS[domain][task]`` is
                called with ``*args, **kwargs`` to construct the environment.
            normalize: If True, wrap the env in ``NormalizeActionWrapper``.
            observation_keys: Keys of a ``spaces.Dict`` observation space to
                use. When falsy and the env has a ``spaces.Dict`` observation
                space, all of that space's keys are stored instead.
            unwrap_time_limit: If True, strip an outer ``wrappers.TimeLimit``.
            *args, **kwargs: Forwarded to the base class and to the env
                constructor.
        """
        self.normalize = normalize
        self.observation_keys = observation_keys
        self.unwrap_time_limit = unwrap_time_limit

        # ``locals()`` is passed as-is, so no extra local variable should be
        # bound before this call.
        self._Serializable__initialize(locals())
        super(GymAdapter, self).__init__(domain, task, *args, **kwargs)

        env = GYM_ENVIRONMENTS[domain][task](*args, **kwargs)

        if isinstance(env, wrappers.TimeLimit) and unwrap_time_limit:
            # Remove the TimeLimit wrapper that sets 'done = True' when
            # the time limit specified for each environment has been passed and
            # therefore the environment is not Markovian (terminal condition
            # depends on time rather than state).
            env = env.env

        if isinstance(env.observation_space, spaces.Dict):
            # Default to every key of the Dict space. Store the result on
            # ``self`` so the resolved keys are not silently discarded.
            self.observation_keys = (
                observation_keys
                or list(env.observation_space.spaces.keys()))

        if normalize:
            env = NormalizeActionWrapper(env)

        self._env = env
# Example #2 (rating: 0)
    def __init__(self,
                 domain,
                 task,
                 *args,
                 env=None,
                 normalize=True,
                 observation_keys=None,
                 unwrap_time_limit=True,
                 **kwargs):
        """Wrap a gym environment and expose a Dict observation space.

        Args:
            domain: Gym domain; combined with ``task`` into the env id
                ``"{domain}-{task}"``. Must be None when ``env`` is given.
            task: Gym task; see ``domain``.
            *args: Not supported; must be empty.
            env: Optional pre-built environment used instead of
                ``gym.envs.make``.
            normalize: If True, wrap the env in ``NormalizeActionWrapper``.
            observation_keys: Keys to keep from a ``spaces.Dict`` observation
                space. When falsy, every key of the space is kept. Ignored
                for ``spaces.Box`` spaces, which are exposed under the single
                key ``DEFAULT_OBSERVATION_KEY``.
            unwrap_time_limit: If True, strip an outer ``wrappers.TimeLimit``.
            **kwargs: Forwarded to the base class and to ``gym.envs.make``.

        Raises:
            NotImplementedError: If the env's observation space is neither a
                ``spaces.Dict`` nor a ``spaces.Box``.
        """
        assert not args, (
            "Gym environments don't support args. Use kwargs instead.")

        self.normalize = normalize
        self.observation_keys = observation_keys
        self.unwrap_time_limit = unwrap_time_limit

        super(GymAdapter, self).__init__(domain, task, *args, **kwargs)

        if env is None:
            assert (domain is not None and task is not None), (domain, task)
            env_id = "{}-{}".format(domain, task)
            env = gym.envs.make(env_id, **kwargs)
        else:
            assert domain is None and task is None, (domain, task)

        if isinstance(env, wrappers.TimeLimit) and unwrap_time_limit:
            # Remove the TimeLimit wrapper that sets 'done = True' when
            # the time limit specified for each environment has been passed and
            # therefore the environment is not Markovian (terminal condition
            # depends on time rather than state).
            env = env.env

        if normalize:
            env = NormalizeActionWrapper(env)

        self._env = env

        env_observation_space = self._env.observation_space
        if isinstance(env_observation_space, spaces.Dict):
            dict_observation_space = env_observation_space
            self.observation_keys = (
                observation_keys
                or (*env_observation_space.spaces.keys(), ))
        elif isinstance(env_observation_space, spaces.Box):
            # Present a flat Box space as a single-entry Dict space.
            dict_observation_space = spaces.Dict(
                OrderedDict(((DEFAULT_OBSERVATION_KEY,
                              env_observation_space), )))
            self.observation_keys = (DEFAULT_OBSERVATION_KEY, )
        else:
            # Without this branch ``dict_observation_space`` would be unbound
            # below and fail with an opaque UnboundLocalError.
            raise NotImplementedError(
                "Observation space ({}) is not supported; expected"
                " spaces.Dict or spaces.Box.".format(env_observation_space))

        # Keep only the selected keys; deep-copy so later mutation of this
        # adapter's spaces does not affect the wrapped env's spaces.
        self._observation_space = type(dict_observation_space)([
            (name, copy.deepcopy(space))
            for name, space in dict_observation_space.spaces.items()
            if name in self.observation_keys
        ])
# Example #3 (rating: 0)
    def __init__(self,
                 domain,
                 task,
                 *args,
                 env=None,
                 normalize=True,
                 observation_keys=None,
                 unwrap_time_limit=True,
                 **kwargs):
        """Wrap a gym environment, either created by id or passed in directly.

        Args:
            domain: Gym domain; joined with ``task`` as ``"{domain}-{task}"``
                to build the env id. Must be None when ``env`` is supplied.
            task: Gym task; see ``domain``.
            *args: Not supported; must be empty.
            env: Optional ready-made environment to wrap instead of calling
                ``gym.envs.make``.
            normalize: When True, the env is wrapped in
                ``NormalizeActionWrapper``.
            observation_keys: Selected keys of a ``spaces.Dict`` observation
                space; when falsy for a Dict space, all keys are used.
            unwrap_time_limit: When True, an outer ``wrappers.TimeLimit`` is
                peeled off.
            **kwargs: Passed to the base class and to ``gym.envs.make``.
        """
        assert not args, (
            "Gym environments don't support args. Use kwargs instead.")

        # ``locals()`` is handed over unchanged, so this call must stay ahead
        # of every new local binding in this method.
        self._Serializable__initialize(locals())

        self.normalize = normalize
        self.unwrap_time_limit = unwrap_time_limit

        super(GymAdapter, self).__init__(domain, task, *args, **kwargs)

        if env is not None:
            assert domain is None and task is None, (domain, task)
            wrapped = env
        else:
            assert (domain is not None and task is not None), (domain, task)
            wrapped = gym.envs.make(f"{domain}-{task}", **kwargs)

        # A TimeLimit wrapper forces 'done = True' once the episode's time
        # budget runs out, making termination depend on elapsed time instead
        # of state (i.e. non-Markovian), so it is stripped when requested.
        if isinstance(wrapped, wrappers.TimeLimit) and unwrap_time_limit:
            wrapped = wrapped.env

        has_dict_observations = isinstance(
            wrapped.observation_space, spaces.Dict)
        if has_dict_observations and not observation_keys:
            observation_keys = tuple(wrapped.observation_space.spaces.keys())

        self.observation_keys = observation_keys

        self._env = NormalizeActionWrapper(wrapped) if normalize else wrapped
# Example #4 (rating: 0)
    def __init__(self,
                 domain,
                 task,
                 *args,
                 env=None,
                 normalize=True,
                 observation_keys=None,
                 unwrap_time_limit=True,
                 **kwargs):
        """Wrap a gym environment, with special handling for safety_gym ids.

        Args:
            domain: Gym domain; joined with ``task`` as ``"{domain}-{task}"``.
                Must be None when ``env`` is given.
            task: Gym task; see ``domain``.
            *args: Not supported; must be empty.
            env: Optional pre-built environment. Its id is treated as
                ``'custom'``.
            normalize: If True, wrap the env in ``NormalizeActionWrapper``.
            observation_keys: Keys of a ``spaces.Dict`` observation space to
                use; when falsy for a Dict space, all of its keys are stored.
            unwrap_time_limit: If True, strip an outer ``wrappers.TimeLimit``.
            **kwargs: Forwarded to the base class and to ``gym.envs.make``.

        If the env id is in ``SAFETY_WRAPPER_IDS``, the env built above is
        replaced by ``Engine(config)``. Here ``config`` is read from
        ``../gym/safety_gym/configs/<env_id>_config.json`` (relative to this
        file) and written back to ``<env_id>_config.json`` in the current
        working directory. When ``config['stacks'] > 1``, the Engine is
        wrapped in ``DummyVecEnv`` + ``VecFrameStack``.
        """
        assert not args, (
            "Gym environments don't support args. Use kwargs instead.")

        self.normalize = normalize
        self.observation_keys = observation_keys
        self.unwrap_time_limit = unwrap_time_limit
        # Frame-stacking defaults; overridden below from a safety_gym config.
        self.stacks = 1
        self.stacking_axis = 0

        # ``locals()`` is passed as-is, so no extra local variable should be
        # bound before this call.
        self._Serializable__initialize(locals())
        super(GymAdapter, self).__init__(domain, task, *args, **kwargs)

        if env is None:
            assert (domain is not None and task is not None), (domain, task)
            env_id = f"{domain}-{task}"
            env = gym.envs.make(env_id, **kwargs)
        else:
            assert domain is None and task is None, (domain, task)
            env_id = 'custom'

        if isinstance(env, wrappers.TimeLimit) and unwrap_time_limit:
            # Remove the TimeLimit wrapper that sets 'done = True' when
            # the time limit specified for each environment has been passed and
            # therefore the environment is not Markovian (terminal condition
            # depends on time rather than state).
            env = env.env

        if isinstance(env.observation_space, spaces.Dict):
            # Store the resolved keys; previously they were computed into a
            # local and never used.
            self.observation_keys = (
                observation_keys or list(env.observation_space.spaces.keys()))
        if normalize:
            env = NormalizeActionWrapper(env)

        # --- safety_gym specific wrapping ---
        if env_id in SAFETY_WRAPPER_IDS:
            # NOTE(review): the env built above (including TimeLimit
            # unwrapping and action normalization) is discarded and replaced
            # by a fresh Engine built from the JSON config.
            dirname = os.path.dirname(os.path.abspath(__file__))
            config_path = (
                f'{dirname}/../gym/safety_gym/configs/{env_id}_config.json')
            with open(config_path, 'r') as fp:
                config = json.load(fp)

            env = Engine(config)

            # Dump the config that was used into the current working
            # directory (i.e. the run's data dir).
            with open(f'{env_id}_config.json', 'w') as fp:
                json.dump(config, fp)

            self.stacks = config.get('stacks', 1)
            self.stacking_axis = config.get('stacking_axis', 0)
            if self.stacks > 1:
                # Bind the Engine as a default argument so the factory returns
                # it regardless of ``env`` being rebound afterwards.
                # NOTE: wrapping in VecNormalize here was tried and reportedly
                # did not work.
                env = DummyVecEnv([lambda engine=env: engine])
                env = VecFrameStack(env, self.stacks)
        # --- end safety_gym specific wrapping ---

        self._env = env
    def __init__(self,
                 domain,
                 task,
                 *args,
                 env=None,
                 normalize=True,
                 observation_keys=(),
                 goal_keys=(),
                 unwrap_time_limit=True,
                 pixel_wrapper_kwargs=None,
                 **kwargs):
        """Wrap a gym environment and expose a Dict observation space.

        Args:
            domain: Gym domain; joined with ``task`` as ``"{domain}-{task}"``.
                Must be None when ``env`` is given.
            task: Gym task; see ``domain``.
            *args: Not supported; must be empty.
            env: Optional pre-built environment. When given, ``kwargs`` must
                be empty.
            normalize: If True, wrap the env in ``NormalizeActionWrapper``.
            observation_keys: Keys to keep from a ``spaces.Dict`` observation
                space; when falsy, every key is kept. Ignored for
                ``spaces.Box`` spaces, which are exposed under
                ``DEFAULT_OBSERVATION_KEY``.
            goal_keys: Forwarded to the base class constructor.
            unwrap_time_limit: If True, strip an outer ``wrappers.TimeLimit``.
            pixel_wrapper_kwargs: If not None, the env is wrapped in
                ``PixelObservationWrapper(env, **pixel_wrapper_kwargs)``.
            **kwargs: Forwarded to the base class and to ``gym.envs.make``;
                also stored as ``self._env_kwargs`` when the env is made here.

        Raises:
            NotImplementedError: If the observation space is neither a
                ``spaces.Dict`` nor a ``spaces.Box``, or the action space is
                not one-dimensional.
        """
        assert not args, (
            "Gym environments don't support args. Use kwargs instead.")

        self.normalize = normalize
        self.unwrap_time_limit = unwrap_time_limit

        super(GymAdapter, self).__init__(domain,
                                         task,
                                         *args,
                                         goal_keys=goal_keys,
                                         **kwargs)

        if env is None:
            assert (domain is not None and task is not None), (domain, task)
            env_id = f"{domain}-{task}"
            env = gym.envs.make(env_id, **kwargs)
            self._env_kwargs = kwargs
        else:
            assert not kwargs
            assert domain is None and task is None, (domain, task)

        if isinstance(env, wrappers.TimeLimit) and unwrap_time_limit:
            # Remove the TimeLimit wrapper that sets 'done = True' when
            # the time limit specified for each environment has been passed and
            # therefore the environment is not Markovian (terminal condition
            # depends on time rather than state).
            env = env.env

        if normalize:
            env = NormalizeActionWrapper(env)

        if pixel_wrapper_kwargs is not None:
            env = PixelObservationWrapper(env, **pixel_wrapper_kwargs)

        self._env = env

        env_observation_space = self._env.observation_space
        if isinstance(env_observation_space, spaces.Dict):
            dict_observation_space = env_observation_space
            # Read keys from the wrapped env's space rather than
            # ``self.observation_space``: that property may be backed by
            # ``self._observation_space``, which is only assigned below.
            self.observation_keys = (
                observation_keys
                or (*dict_observation_space.spaces.keys(), ))
        elif isinstance(env_observation_space, spaces.Box):
            # Present a flat Box space as a single-entry Dict space.
            dict_observation_space = spaces.Dict(
                OrderedDict(((DEFAULT_OBSERVATION_KEY,
                              env_observation_space), )))
            self.observation_keys = (DEFAULT_OBSERVATION_KEY, )
        else:
            # Without this branch ``dict_observation_space`` would be unbound
            # below and fail with an opaque UnboundLocalError.
            raise NotImplementedError(
                "Observation space ({}) is not supported; expected"
                " spaces.Dict or spaces.Box.".format(env_observation_space))

        # Keep only the selected keys; deep-copy so the adapter's spaces are
        # independent of the wrapped env's spaces.
        self._observation_space = type(dict_observation_space)([
            (name, copy.deepcopy(space))
            for name, space in dict_observation_space.spaces.items()
            if name in self.observation_keys
        ])

        if len(self._env.action_space.shape) > 1:
            raise NotImplementedError(
                "Shape of the action space ({}) is not flat, make sure to"
                " check the implementation.".format(self._env.action_space))

        self._action_space = self._env.action_space