コード例 #1
0
 def clone_wrapped(env):
     """
     Produce a clone of an AsaEnv that is nested inside a chain of wrappers.

     The chain is peeled off by following ``.env`` until an AsaEnv instance is
     reached. The AsaEnv is cloned with Serializable.clone and the wrappers are
     then rebuilt around the clone, from the innermost one outwards:
     - TfEnv (exact type): a new TfEnv is constructed around the clone
     - NormalizedEnv (exact type): cloned with the new inner env passed as ``env``
     - any other Serializable wrapper: cloned the same way, and a warning is
       issued through warn_once
     - any other wrapper: its class is called with ``env=`` as the only
       argument, and a warning is issued through warn_once
     :param env: AsaEnv, possibly wrapped in several environment wrappers
     :return: the clone of the inner AsaEnv with all wrappers re-applied
     """
     # Peel the wrappers off, outermost first, until the AsaEnv is reached
     layers = []
     core = env
     while not isinstance(core, AsaEnv):
         layers.append(core)
         core = core.env

     # Clone the innermost environment
     rebuilt = Serializable.clone(core)

     # Rebuild the wrappers around the clone, innermost layer first
     for layer in reversed(layers):
         layer_cls = type(layer)

         if layer_cls is TfEnv:
             rebuilt = TfEnv(rebuilt)
             continue

         if layer_cls is NormalizedEnv:
             # WARNING: obs_mean and obs_var are not copied to original env after skill training!
             rebuilt = Serializable.clone(layer, env=rebuilt)
             continue

         if isinstance(layer, Serializable):
             rebuilt = Serializable.clone(layer, env=rebuilt)
             message = (
                 'AsaEnv: clone_wrapped performed on unknown Serializable wrapper "{}". '
                 'Wrapper was cloned and applied.')
             warn_once(message.format(layer_cls))
             continue

         # Unknown wrapper that cannot be cloned: only the env is passed on,
         # every other constructor argument falls back to its default
         rebuilt = layer_cls(env=rebuilt)
         message = (
             'AsaEnv: clone_wrapped performed on unknown non-Serializable wrapper "{}". '
             'Wrapper was initiated with default parameters and applied.')
         warn_once(message.format(layer_cls))

     return rebuilt
コード例 #2
0
    def create_new_skill(self, end_obss):
        """
        Register a fresh, untrained skill together with its stopping function.

        The skill policy is cloned from the prototype policy under a name built
        from the prototype's class name and the new skill's index. A copy of
        end_obss is stored for the skill. The stopping function reports whether
        the last observation of a path equals any of the unique rows of those
        stored observations.
        :param end_obss: end observations of the new skill; assumed to be a 2-D
                         array with one observation per row (TODO confirm)
        :return: new skill policy and skill ID (index of the skill)
        :rtype: tuple(garage.policies.base.Policy, int)
        """
        skill_id = len(self.skill_policies)
        prototype = self.skill_policy_prototype
        skill_name = '{}Skill{}'.format(type(prototype).__name__, skill_id)
        skill_policy = Serializable.clone(obj=prototype, name=skill_name)

        self.skill_policies.append(skill_policy)
        self._skills_end_obss.append(np.copy(end_obss))

        # Duplicate rows are dropped once here rather than on every stop check
        end_rows = np.unique(self._skills_end_obss[skill_id], axis=0)

        def _ends_in_known_obs(path):
            # True when the final observation matches some stored row exactly
            final_obs = path['observations'][-1]
            return (final_obs == end_rows).all(axis=1).any()

        self._skill_stop_functions.append(_ends_in_known_obs)
        return skill_policy, skill_id