def clone_wrapped(env):
    """
    Clone an AsaEnv together with every environment wrapper stacked around it.

    The wrapper chain is peeled off one layer at a time (via each layer's
    ``.env`` attribute) until the innermost AsaEnv is reached. That AsaEnv is
    cloned with Serializable.clone. The wrappers are then re-applied from the
    innermost to the outermost:

    - TfEnv: not cloned; a fresh TfEnv is constructed around the clone
    - NormalizedEnv: cloned with Serializable.clone, rebound to the new inner env
    - any other Serializable wrapper: cloned and rebound, with a one-time warning
    - any other wrapper: constructed as ``wrapper_cls(env=...)`` with no other
      arguments, with a one-time warning

    Wrapper types are matched exactly (``is``), so subclasses of TfEnv or
    NormalizedEnv fall through to the generic branches.

    :param env: AsaEnv wrapped in zero or more environment wrappers
    :return: clone of the inner AsaEnv, re-wrapped in the same wrapper order
    """
    # Peel wrappers, outermost first; remember each layer and its class.
    layers = []
    inner = env
    while not isinstance(inner, AsaEnv):
        layers.append((type(inner), inner))
        inner = inner.env

    cloned = Serializable.clone(inner)

    # Re-apply wrappers in reverse peeling order (innermost wrapper first).
    for layer_cls, layer in reversed(layers):
        if layer_cls is TfEnv:
            cloned = TfEnv(cloned)
            continue

        if layer_cls is NormalizedEnv:
            # WARNING: obs_mean and obs_var are not copied to original env after skill training!
            cloned = Serializable.clone(layer, env=cloned)
            continue

        if isinstance(layer, Serializable):
            cloned = Serializable.clone(layer, env=cloned)
            warn_once(
                'AsaEnv: clone_wrapped performed on unknown Serializable wrapper "{}". '
                'Wrapper was cloned and applied.'.format(layer_cls))
            continue

        # Unknown, non-Serializable wrapper: best effort with default arguments.
        cloned = layer_cls(env=cloned)
        warn_once(
            'AsaEnv: clone_wrapped performed on unknown non-Serializable wrapper "{}". '
            'Wrapper was initiated with default parameters and applied.'
            .format(layer_cls))

    return cloned
def create_new_skill(self, end_obss):
    """
    Create a new untrained skill and register it along with its stopping function.

    The new policy is a Serializable.clone of ``self.skill_policy_prototype``,
    named ``<PrototypeClassName>Skill<id>``, and is appended to
    ``self.skill_policies``. A copy of ``end_obss`` is appended to
    ``self._skills_end_obss``. A stopping function is appended to
    ``self._skill_stop_functions``. Given a path dict, it is truthy when the
    path's last observation exactly equals any row of the deduplicated end
    observations.

    :param end_obss: observations at which the new skill should stop;
        presumably a 2-D array of shape (num_obs, obs_dim), since rows are
        deduplicated with ``axis=0`` and compared with ``all(axis=1)``.
        TODO: confirm against callers.
    :return: new skill policy and skill ID (index of the skill)
    :rtype: tuple(garage.policies.base.Policy, int)
    """
    skill_id = len(self.skill_policies)
    prototype = self.skill_policy_prototype
    skill_policy = Serializable.clone(
        obj=prototype,
        name='{}Skill{}'.format(type(prototype).__name__, skill_id))
    self.skill_policies.append(skill_policy)

    self._skills_end_obss.append(np.copy(end_obss))
    # Deduplicate end observations once, outside the per-path predicate.
    targets = np.unique(self._skills_end_obss[skill_id], axis=0)

    def stop_fn(path):
        # Exact element-wise match of the final observation against any target row.
        final_obs = path['observations'][-1]
        return (final_obs == targets).all(axis=1).any()

    self._skill_stop_functions.append(stop_fn)
    return skill_policy, skill_id