def __init__(self, mdp_info, clip_obs=10., alpha=1e-32):
    """
    Set up the running observation standardizer.

    Args:
        mdp_info (MDPInfo): information of the MDP; the shape of its
            observation space fixes the expected observation shape;
        clip_obs (float, 10.): bound used to clip the normalized
            observations;
        alpha (float, 1e-32): moving average catchup parameter handed to
            the running standardization.

    """
    self.clip_obs = clip_obs
    expected_shape = mdp_info.observation_space.shape
    self.obs_shape = expected_shape
    self.obs_runstand = RunningStandardization(
        shape=expected_shape,
        alpha=alpha,
    )
class StandardizationPreprocessor(object):
    """
    Preprocess observations from the environment using a running
    standardization.

    Every call first feeds the observation into the running statistics and
    then returns ``(obs - mean) / std`` clipped to ``[-clip_obs, clip_obs]``.

    """
    def __init__(self, mdp_info, clip_obs=10., alpha=1e-32):
        """
        Constructor.

        Args:
            mdp_info (MDPInfo): information of the MDP; only
                ``observation_space.shape`` is read here;
            clip_obs (float, 10.): values to clip the normalized observations;
            alpha (float, 1e-32): moving average catchup parameter for the
                normalization, forwarded to ``RunningStandardization``.

        """
        self.clip_obs = clip_obs
        self.obs_shape = mdp_info.observation_space.shape
        self.obs_runstand = RunningStandardization(shape=self.obs_shape,
                                                   alpha=alpha)

    def __call__(self, obs):
        """
        Update the running statistics with ``obs`` and return it normalized.

        Args:
            obs (np.ndarray): observation to be normalized; its shape must
                equal ``self.obs_shape``.

        Returns:
            Normalized observation array with the same shape, clipped to
            ``[-clip_obs, clip_obs]``.

        Raises:
            AssertionError: if ``obs.shape`` differs from ``self.obs_shape``.

        """
        # Explicit check rather than an ``assert`` statement: asserts are
        # removed under ``python -O``, which would let a wrongly-shaped
        # observation through (it could then broadcast against the
        # statistics instead of failing). AssertionError is kept so callers
        # see the same exception type as before.
        if obs.shape != self.obs_shape:
            raise AssertionError(
                "Values given to running_norm have incorrect shape "
                "(obs shape: {},  expected shape: {})"
                .format(obs.shape, self.obs_shape))

        # Statistics are updated before normalizing, so the current
        # observation already contributes to the mean/std used below.
        self.obs_runstand.update_stats(obs)
        norm_obs = np.clip(
            (obs - self.obs_runstand.mean) / self.obs_runstand.std,
            -self.clip_obs, self.clip_obs)

        return norm_obs

    def get_state(self):
        """
        Returns:
            A dictionary with the normalization state, as produced by the
            running standardization's ``get_state``.

        """
        return self.obs_runstand.get_state()

    def set_state(self, data):
        """
        Set the current normalization state from the data dict.

        Args:
            data (dict): normalization state, e.g. as returned by
                ``get_state``.

        """
        self.obs_runstand.set_state(data)

    def save_state(self, path):
        """
        Save the running normalization state to path.

        Args:
            path (str): path to save the running normalization state.

        """
        with open(path, 'wb') as f:
            pickle.dump(self.get_state(), f, protocol=3)

    def load_state(self, path):
        """
        Load the running normalization state from path.

        Args:
            path (str): path to load the running normalization state from.

        """
        # NOTE(review): pickle.load can execute arbitrary code embedded in
        # the file; only load files written by ``save_state`` from a trusted
        # source.
        with open(path, 'rb') as f:
            data = pickle.load(f)
        # Apply the state only after the file has been closed.
        self.set_state(data)
# Example #3
# 0
class StandardizationPreprocessor(Serializable):
    """
    Preprocess observations from the environment using a running
    standardization.

    Every call first feeds the observation into the running statistics and
    then returns ``(obs - mean) / std`` clipped to ``[-clip_obs, clip_obs]``.

    """
    def __init__(self, mdp_info, clip_obs=10., alpha=1e-32):
        """
        Constructor.

        Args:
            mdp_info (MDPInfo): information of the MDP; only
                ``observation_space.shape`` is read here;
            clip_obs (float, 10.): values to clip the normalized observations;
            alpha (float, 1e-32): moving average catchup parameter for the
                normalization, forwarded to ``RunningStandardization``.

        """
        self._clip_obs = clip_obs
        self._obs_shape = mdp_info.observation_space.shape
        self._obs_runstand = RunningStandardization(shape=self._obs_shape,
                                                    alpha=alpha)

        # Register attributes with the Serializable base. Presumably
        # 'primitive' stores plain values and 'mushroom' delegates to the
        # nested object's own serialization — confirm against Serializable.
        self._add_save_attr(_clip_obs='primitive',
                            _obs_shape='primitive',
                            _obs_runstand='mushroom')

    def __call__(self, obs):
        """
        Update the running statistics with ``obs`` and return it normalized.

        Args:
            obs (np.ndarray): observation to be normalized; its shape must
                equal the observation-space shape given at construction.

        Returns:
            Normalized observation array with the same shape, clipped to
            ``[-clip_obs, clip_obs]``.

        Raises:
            AssertionError: if ``obs.shape`` differs from the expected shape.

        """
        # Explicit check rather than an ``assert`` statement: asserts are
        # removed under ``python -O``, which would let a wrongly-shaped
        # observation through (it could then broadcast against the
        # statistics instead of failing). AssertionError is kept so callers
        # see the same exception type as before.
        if obs.shape != self._obs_shape:
            raise AssertionError(
                "Values given to running_norm have incorrect shape "
                "(obs shape: {},  expected shape: {})"
                .format(obs.shape, self._obs_shape))

        # Statistics are updated before normalizing, so the current
        # observation already contributes to the mean/std used below.
        self._obs_runstand.update_stats(obs)
        norm_obs = np.clip(
            (obs - self._obs_runstand.mean) / self._obs_runstand.std,
            -self._clip_obs, self._clip_obs)

        return norm_obs

    def get_state(self):
        """
        Returns:
            A dictionary with the normalization state, as produced by the
            running standardization's ``get_state``.

        """
        return self._obs_runstand.get_state()

    def set_state(self, data):
        """
        Set the current normalization state from the data dict.

        Args:
            data (dict): normalization state, e.g. as returned by
                ``get_state``.

        """
        self._obs_runstand.set_state(data)