def __init__(self, mdp_info, clip_obs=10., alpha=1e-32):
    """
    Constructor.

    Stores the clipping bound and the observation shape, then builds the
    running standardization object sized to that shape.

    Args:
        mdp_info (MDPInfo): information of the MDP; only
            ``observation_space.shape`` is read from it;
        clip_obs (float, 10.): bound used to clip the normalized
            observations;
        alpha (float, 1e-32): moving average catchup parameter for the
            normalization, forwarded unchanged to
            ``RunningStandardization``.

    """
    self.clip_obs = clip_obs

    observation_shape = mdp_info.observation_space.shape
    self.obs_shape = observation_shape
    self.obs_runstand = RunningStandardization(
        shape=observation_shape,
        alpha=alpha,
    )
class StandardizationPreprocessor(object):
    """
    Preprocess observations from the environment using a running
    standardization.

    Each call first adds the incoming observation to the running
    statistics, then returns ``(obs - mean) / std`` clipped to
    ``[-clip_obs, clip_obs]``.

    """
    def __init__(self, mdp_info, clip_obs=10., alpha=1e-32):
        """
        Constructor.

        Args:
            mdp_info (MDPInfo): information of the MDP; only
                ``observation_space.shape`` is read from it;
            clip_obs (float, 10.): values to clip the normalized
                observations;
            alpha (float, 1e-32): moving average catchup parameter for the
                normalization, forwarded to ``RunningStandardization``.

        """
        self.clip_obs = clip_obs
        self.obs_shape = mdp_info.observation_space.shape
        self.obs_runstand = RunningStandardization(shape=self.obs_shape,
                                                   alpha=alpha)

    def __call__(self, obs):
        """
        Call function to normalize the observation.

        The observation is added to the running statistics *before* it is
        normalized, so it contributes to the mean/std applied to it.

        Args:
            obs (np.ndarray): observation to be normalized.

        Returns:
            Normalized observation array with the same shape.

        Raises:
            AssertionError: if ``obs.shape`` differs from the observation
                shape given at construction.

        """
        # Explicit check rather than ``assert``: asserts are stripped under
        # ``python -O``, which would let a mis-shaped observation reach
        # ``update_stats`` and be (possibly silently) broadcast below.
        # AssertionError is kept so existing callers see no change.
        if obs.shape != self.obs_shape:
            raise AssertionError(
                "Values given to running_norm have incorrect shape "
                "(obs shape: {}, expected shape: {})"
                .format(obs.shape, self.obs_shape))

        self.obs_runstand.update_stats(obs)
        norm_obs = np.clip(
            (obs - self.obs_runstand.mean) / self.obs_runstand.std,
            -self.clip_obs, self.clip_obs)
        return norm_obs

    def get_state(self):
        """
        Returns:
            A dictionary with the normalization state, as produced by the
            underlying running standardization's ``get_state``.

        """
        return self.obs_runstand.get_state()

    def set_state(self, data):
        """
        Set the current normalization state from the data dict.

        Args:
            data (dict): state previously returned by ``get_state``.

        """
        self.obs_runstand.set_state(data)

    def save_state(self, path):
        """
        Save the running normalization state to path.

        Args:
            path (str): path to save the running normalization state.

        """
        with open(path, 'wb') as f:
            pickle.dump(self.get_state(), f, protocol=3)

    def load_state(self, path):
        """
        Load the running normalization state from path.

        Args:
            path (string): path to load the running normalization state
                from.

        """
        # SECURITY: pickle.load can execute arbitrary code embedded in the
        # file. Only load files written by ``save_state`` from a trusted
        # source.
        with open(path, 'rb') as f:
            data = pickle.load(f)
        self.set_state(data)
class StandardizationPreprocessor(Serializable):
    """
    Preprocess observations from the environment using a running
    standardization.

    Each call first adds the incoming observation to the running
    statistics, then returns ``(obs - mean) / std`` clipped to
    ``[-clip_obs, clip_obs]``.

    """
    def __init__(self, mdp_info, clip_obs=10., alpha=1e-32):
        """
        Constructor.

        Args:
            mdp_info (MDPInfo): information of the MDP; only
                ``observation_space.shape`` is read from it;
            clip_obs (float, 10.): values to clip the normalized
                observations;
            alpha (float, 1e-32): moving average catchup parameter for the
                normalization, forwarded to ``RunningStandardization``.

        """
        self._clip_obs = clip_obs
        self._obs_shape = mdp_info.observation_space.shape
        self._obs_runstand = RunningStandardization(shape=self._obs_shape,
                                                    alpha=alpha)

        # Presumably registers these attributes with Serializable's
        # save/load machinery; the 'primitive'/'mushroom' tags are
        # interpreted by Serializable (defined elsewhere) -- confirm there.
        self._add_save_attr(
            _clip_obs='primitive',
            _obs_shape='primitive',
            _obs_runstand='mushroom'
        )

    def __call__(self, obs):
        """
        Call function to normalize the observation.

        The observation is added to the running statistics *before* it is
        normalized, so it contributes to the mean/std applied to it.

        Args:
            obs (np.ndarray): observation to be normalized.

        Returns:
            Normalized observation array with the same shape.

        Raises:
            AssertionError: if ``obs.shape`` differs from the observation
                shape given at construction.

        """
        # Explicit check rather than ``assert``: asserts are stripped under
        # ``python -O``, which would let a mis-shaped observation reach
        # ``update_stats`` and be (possibly silently) broadcast below.
        # AssertionError is kept so existing callers see no change.
        if obs.shape != self._obs_shape:
            raise AssertionError(
                "Values given to running_norm have incorrect shape "
                "(obs shape: {}, expected shape: {})"
                .format(obs.shape, self._obs_shape))

        self._obs_runstand.update_stats(obs)
        norm_obs = np.clip(
            (obs - self._obs_runstand.mean) / self._obs_runstand.std,
            -self._clip_obs, self._clip_obs)
        return norm_obs

    def get_state(self):
        """
        Returns:
            A dictionary with the normalization state, as produced by the
            underlying running standardization's ``get_state``.

        """
        return self._obs_runstand.get_state()

    def set_state(self, data):
        """
        Set the current normalization state from the data dict.

        Args:
            data (dict): state previously returned by ``get_state``.

        """
        self._obs_runstand.set_state(data)