def test_std_share_network_shapes(self, output_dim, hidden_dim):
        model = GaussianGRUModel(output_dim=output_dim,
                                 hidden_dim=hidden_dim,
                                 std_share_network=True,
                                 hidden_nonlinearity=None,
                                 recurrent_nonlinearity=None,
                                 hidden_w_init=self.default_initializer,
                                 recurrent_w_init=self.default_initializer,
                                 output_w_init=self.default_initializer)
        step_hidden_var = tf.compat.v1.placeholder(shape=(self.batch_size,
                                                          hidden_dim),
                                                   name='step_hidden',
                                                   dtype=tf.float32)
        (mean_var, step_mean_var, log_std_var, step_log_std_var, step_hidden,
         hidden_init_var, dist) = model.build(self.input_var,
                                              self.step_input_var,
                                              step_hidden_var)

        # The output layer is a tf.keras.layers.Dense object,
        # which cannot be accessed through tf.compat.v1.variable_scope.
        # As a workaround, look it up in tf.compat.v1.global_variables().
        for var in tf.compat.v1.global_variables():
            if 'output_layer/kernel' in var.name:
                std_share_output_weights = var
            if 'output_layer/bias' in var.name:
                std_share_output_bias = var
        assert std_share_output_weights.shape[1] == output_dim * 2
        assert std_share_output_bias.shape == output_dim * 2
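With std_share_network=True, the mean and log-std come from a single output
head of width 2 * output_dim, which is what the two asserts above check. A
minimal NumPy sketch of how such a shared head splits (hypothetical names,
not the model's actual code):

import numpy as np

batch, output_dim = 3, 2
# one shared linear head producing [mean | log_std] side by side
shared_out = np.random.randn(batch, 2 * output_dim)
mean = shared_out[..., :output_dim]      # first half: mean
log_std = shared_out[..., output_dim:]   # second half: log std
assert mean.shape == log_std.shape == (batch, output_dim)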
Example 2
    def test_dist(self):
        model = GaussianGRUModel(output_dim=1, hidden_dim=1)
        step_hidden_var = tf.compat.v1.placeholder(shape=(self.batch_size, 1),
                                                   name='step_hidden',
                                                   dtype=tf.float32)
        model.build(self.input_var, self.step_input_var, step_hidden_var)
        assert isinstance(model.networks['default'].dist,
                          tfp.distributions.MultivariateNormalDiag)
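For reference, tfp.distributions.MultivariateNormalDiag is parameterized by a
mean vector and a per-dimension scale; a standalone sketch with illustrative
values (runs eagerly under TF2, or as graph tensors under tf.compat.v1):

import numpy as np
import tensorflow_probability as tfp

# diagonal Gaussian over a 2-D action space
dist = tfp.distributions.MultivariateNormalDiag(
    loc=np.zeros(2, dtype=np.float32),        # mean
    scale_diag=np.ones(2, dtype=np.float32))  # per-dimension std
sample = dist.sample()         # one action-shaped draw
log_p = dist.log_prob(sample)  # scalar log density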
Example 3
    def __init__(self,
                 env_spec,
                 hidden_dims=(32, ),
                 name='GaussianGRUPolicy',
                 hidden_nonlinearity=tf.nn.tanh,
                 hidden_w_init=tf.glorot_uniform_initializer(),
                 hidden_b_init=tf.zeros_initializer(),
                 recurrent_nonlinearity=tf.nn.sigmoid,
                 recurrent_w_init=tf.glorot_uniform_initializer(),
                 output_nonlinearity=None,
                 output_w_init=tf.glorot_uniform_initializer(),
                 output_b_init=tf.zeros_initializer(),
                 hidden_state_init=tf.zeros_initializer(),
                 hidden_state_init_trainable=False,
                 learn_std=True,
                 std_share_network=False,
                 init_std=1.0,
                 layer_normalization=False,
                 state_include_action=True):
        if not isinstance(env_spec.action_space, akro.Box):
            raise ValueError('GaussianGRUPolicy only works with '
                             'akro.Box action spaces, not {}'.format(
                                 env_spec.action_space))
        super().__init__(name, env_spec)
        self._obs_dim = env_spec.observation_space.flat_dim
        self._action_dim = env_spec.action_space.flat_dim
        self._hidden_dims = hidden_dims
        self._state_include_action = state_include_action

        if state_include_action:
            self._input_dim = self._obs_dim + self._action_dim
        else:
            self._input_dim = self._obs_dim

        self.model = GaussianGRUModel(
            output_dim=self._action_dim,
            hidden_dims=hidden_dims,
            name='GaussianGRUModel',
            hidden_nonlinearity=hidden_nonlinearity,
            hidden_w_init=hidden_w_init,
            hidden_b_init=hidden_b_init,
            recurrent_nonlinearity=recurrent_nonlinearity,
            recurrent_w_init=recurrent_w_init,
            output_nonlinearity=output_nonlinearity,
            output_w_init=output_w_init,
            output_b_init=output_b_init,
            hidden_state_init=hidden_state_init,
            hidden_state_init_trainable=hidden_state_init_trainable,
            layer_normalization=layer_normalization,
            learn_std=learn_std,
            std_share_network=std_share_network,
            init_std=init_std)

        self._prev_actions = None
        self._prev_hiddens = None
        self._initialize()
Example 4
    def test_without_std_share_network_output_values(self, mock_normal,
                                                     output_dim, hidden_dim,
                                                     init_std):
        mock_normal.return_value = 0.5
        model = GaussianGRUModel(output_dim=output_dim,
                                 hidden_dim=hidden_dim,
                                 std_share_network=False,
                                 hidden_nonlinearity=None,
                                 recurrent_nonlinearity=None,
                                 hidden_w_init=self.default_initializer,
                                 recurrent_w_init=self.default_initializer,
                                 output_w_init=self.default_initializer,
                                 init_std=init_std)
        step_hidden_var = tf.compat.v1.placeholder(shape=(self.batch_size,
                                                          hidden_dim),
                                                   name='step_hidden',
                                                   dtype=tf.float32)
        (mean_var, step_mean_var, log_std_var, step_log_std_var, step_hidden,
         hidden_init_var, dist) = model.build(self.input_var,
                                              self.step_input_var,
                                              step_hidden_var)

        hidden1 = hidden2 = np.full((self.batch_size, hidden_dim),
                                    hidden_init_var.eval())

        mean, log_std = self.sess.run(
            [mean_var, log_std_var],
            feed_dict={self.input_var: self.obs_inputs})

        for i in range(self.time_step):
            mean1, log_std1, hidden1 = self.sess.run(
                [step_mean_var, step_log_std_var, step_hidden],
                feed_dict={
                    self.step_input_var: self.obs_input,
                    step_hidden_var: hidden1
                })

            hidden2 = recurrent_step_gru(input_val=self.obs_input,
                                         num_units=hidden_dim,
                                         step_hidden=hidden2,
                                         w_x_init=0.1,
                                         w_h_init=0.1,
                                         b_init=0.,
                                         nonlinearity=None,
                                         gate_nonlinearity=None)

            # weights of the linear output layer, all initialized to 0.1
            output_weights = np.full(
                (np.prod(hidden2.shape[1:]), output_dim), 0.1)
            output2 = np.matmul(hidden2, output_weights)
            assert np.allclose(mean1, output2)
            expected_log_std = np.full((self.batch_size, output_dim),
                                       np.log(init_std))
            assert np.allclose(log_std1, expected_log_std)
            assert np.allclose(hidden1, hidden2)
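The hand computation above relies on recurrent_step_gru, a test helper that
replays one GRU step in NumPy. A hedged sketch of the step it presumably
performs (the per-gate weight layout is a guess at the helper's internals;
with nonlinearity=None and gate_nonlinearity=None the test effectively swaps
tanh/sigmoid for the identity, making the arithmetic linear):

import numpy as np

def sigmoid(v):
    return 1.0 / (1.0 + np.exp(-v))

def gru_step(x, h, w, u, b, act=np.tanh, gate=sigmoid):
    """One GRU step; w, u, b map gate name -> input weights, recurrent
    weights, and bias."""
    z = gate(x @ w['z'] + h @ u['z'] + b['z'])       # update gate
    r = gate(x @ w['r'] + h @ u['r'] + b['r'])       # reset gate
    c = act(x @ w['c'] + (r * h) @ u['c'] + b['c'])  # candidate state
    return z * h + (1.0 - z) * c  # Keras GRU blend: h' = z*h + (1-z)*c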
Example 5
    def test_without_std_share_network_is_pickleable(self, mock_normal,
                                                     output_dim, hidden_dim):
        mock_normal.return_value = 0.5
        model = GaussianGRUModel(output_dim=output_dim,
                                 hidden_dim=hidden_dim,
                                 std_share_network=False,
                                 hidden_nonlinearity=None,
                                 recurrent_nonlinearity=None,
                                 hidden_w_init=self.default_initializer,
                                 recurrent_w_init=self.default_initializer,
                                 output_w_init=self.default_initializer)
        step_hidden_var = tf.compat.v1.placeholder(shape=(self.batch_size,
                                                          hidden_dim),
                                                   name='step_hidden',
                                                   dtype=tf.float32)
        (mean_var, step_mean_var, log_std_var, step_log_std_var, step_hidden,
         _, _) = model.build(self.input_var, self.step_input_var,
                             step_hidden_var)

        # The output layer is a tf.keras.layers.Dense object,
        # which cannot be accessed through tf.compat.v1.variable_scope.
        # As a workaround, look it up in tf.compat.v1.global_variables().
        for var in tf.compat.v1.global_variables():
            if 'output_layer/bias' in var.name:
                var.load(tf.ones_like(var).eval())

        hidden = np.zeros((self.batch_size, hidden_dim))

        outputs1 = self.sess.run([mean_var, log_std_var],
                                 feed_dict={self.input_var: self.obs_inputs})
        output1 = self.sess.run([step_mean_var, step_log_std_var, step_hidden],
                                feed_dict={
                                    self.step_input_var: self.obs_input,
                                    step_hidden_var: hidden
                                })  # noqa: E126

        h = pickle.dumps(model)
        with tf.compat.v1.Session(graph=tf.Graph()) as sess:
            model_pickled = pickle.loads(h)

            input_var = tf.compat.v1.placeholder(tf.float32,
                                                 shape=(None, None,
                                                        self.feature_shape),
                                                 name='input')
            step_input_var = tf.compat.v1.placeholder(
                tf.float32,
                shape=(None, self.feature_shape),
                name='step_input')
            step_hidden_var = tf.compat.v1.placeholder(shape=(self.batch_size,
                                                              hidden_dim),
                                                       name='initial_hidden',
                                                       dtype=tf.float32)

            (mean_var2, step_mean_var2, log_std_var2, step_log_std_var2,
             step_hidden2, _, _) = model_pickled.build(input_var,
                                                       step_input_var,
                                                       step_hidden_var)

            outputs2 = sess.run([mean_var2, log_std_var2],
                                feed_dict={input_var: self.obs_inputs})
            output2 = sess.run(
                [step_mean_var2, step_log_std_var2, step_hidden2],
                feed_dict={
                    step_input_var: self.obs_input,
                    step_hidden_var: hidden
                })
            assert np.array_equal(outputs1, outputs2)
            assert np.array_equal(output1, output2)
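The round trip above works only because the model pickles its configuration
together with current parameter values, then restores them into whatever
graph is active at load time. A toy sketch of that pattern (hypothetical
PickleableModel; the real Model base class in garage/metarl generalizes
this):

import pickle
import tensorflow as tf

class PickleableModel:
    """Toy model that pickles its config plus parameter values, never raw
    tf objects."""

    def __init__(self, dim):
        self._dim = dim
        self.kernel = tf.compat.v1.get_variable(
            'toy_kernel', shape=(dim, ), initializer=tf.zeros_initializer())

    def __getstate__(self):
        # snapshot current values through the active session
        sess = tf.compat.v1.get_default_session()
        return {'dim': self._dim, 'params': sess.run(self.kernel)}

    def __setstate__(self, state):
        # recreate the variable in the current graph, then push values back
        self.__init__(state['dim'])
        self.kernel.load(state['params'])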
Example 6
class GaussianGRUPolicy(StochasticPolicy):
    """Models the action distribution using a Gaussian parameterized by a GRU.

    Args:
        env_spec (metarl.envs.env_spec.EnvSpec): Environment specification.
        name (str): Model name, also the variable scope.
        hidden_dims (list[int]): Hidden dimensions for the GRU cells for
            the mean.
        hidden_nonlinearity (Callable): Activation function for intermediate
            dense layer(s). It should return a tf.Tensor. Set it to
            None to maintain a linear activation.
        hidden_w_init (Callable): Initializer function for the weight
            of intermediate dense layer(s). The function should return a
            tf.Tensor.
        hidden_b_init (Callable): Initializer function for the bias
            of intermediate dense layer(s). The function should return a
            tf.Tensor.
        recurrent_nonlinearity (Callable): Activation function for recurrent
            layers. It should return a tf.Tensor. Set it to None to
            maintain a linear activation.
        recurrent_w_init (Callable): Initializer function for the weight
            of recurrent layer(s). The function should return a
            tf.Tensor.
        output_nonlinearity (Callable): Activation function for output dense
            layer. It should return a tf.Tensor. Set it to None to
            maintain a linear activation.
        output_w_init (Callable): Initializer function for the weight
            of output dense layer(s). The function should return a
            tf.Tensor.
        output_b_init (Callable): Initializer function for the bias
            of output dense layer(s). The function should return a
            tf.Tensor.
        hidden_state_init (Callable): Initializer function for the
            initial hidden state. The function should return a tf.Tensor.
        hidden_state_init_trainable (bool): Bool for whether the initial
            hidden state is trainable.
        learn_std (bool): Whether the std is trainable.
        std_share_network (bool): Boolean for whether mean and std share
            the same network.
        init_std (float): Initial value for std.
        layer_normalization (bool): Bool for using layer normalization or not.
        state_include_action (bool): Whether the state includes action.
            If True, input dimension will be
            (observation dimension + action dimension).

    """
    def __init__(self,
                 env_spec,
                 hidden_dims=(32, ),
                 name='GaussianGRUPolicy',
                 hidden_nonlinearity=tf.nn.tanh,
                 hidden_w_init=tf.glorot_uniform_initializer(),
                 hidden_b_init=tf.zeros_initializer(),
                 recurrent_nonlinearity=tf.nn.sigmoid,
                 recurrent_w_init=tf.glorot_uniform_initializer(),
                 output_nonlinearity=None,
                 output_w_init=tf.glorot_uniform_initializer(),
                 output_b_init=tf.zeros_initializer(),
                 hidden_state_init=tf.zeros_initializer(),
                 hidden_state_init_trainable=False,
                 learn_std=True,
                 std_share_network=False,
                 init_std=1.0,
                 layer_normalization=False,
                 state_include_action=True):
        if not isinstance(env_spec.action_space, akro.Box):
            raise ValueError('GaussianGRUPolicy only works with '
                             'akro.Box action spaces, not {}'.format(
                                 env_spec.action_space))
        super().__init__(name, env_spec)
        self._obs_dim = env_spec.observation_space.flat_dim
        self._action_dim = env_spec.action_space.flat_dim
        self._hidden_dims = hidden_dims
        self._state_include_action = state_include_action

        if state_include_action:
            self._input_dim = self._obs_dim + self._action_dim
        else:
            self._input_dim = self._obs_dim

        self.model = GaussianGRUModel(
            output_dim=self._action_dim,
            hidden_dims=hidden_dims,
            name='GaussianGRUModel',
            hidden_nonlinearity=hidden_nonlinearity,
            hidden_w_init=hidden_w_init,
            hidden_b_init=hidden_b_init,
            recurrent_nonlinearity=recurrent_nonlinearity,
            recurrent_w_init=recurrent_w_init,
            output_nonlinearity=output_nonlinearity,
            output_w_init=output_w_init,
            output_b_init=output_b_init,
            hidden_state_init=hidden_state_init,
            hidden_state_init_trainable=hidden_state_init_trainable,
            layer_normalization=layer_normalization,
            learn_std=learn_std,
            std_share_network=std_share_network,
            init_std=init_std)

        self._prev_actions = None
        self._prev_hiddens = None
        self._initialize()

    def _initialize(self):
        obs_ph = tf.compat.v1.placeholder(tf.float32,
                                          shape=(None, None, self._input_dim))
        step_input_var = tf.compat.v1.placeholder(shape=(None,
                                                         self._input_dim),
                                                  name='step_input',
                                                  dtype=tf.float32)
        step_hidden_var = tf.compat.v1.placeholder(
            shape=(None, self._hidden_dims[0]),
            name='step_hidden_input',
            dtype=tf.float32)

        with tf.compat.v1.variable_scope(self.name) as vs:
            self._variable_scope = vs
            self.model.build(obs_ph, step_input_var, step_hidden_var)

        self._f_step_mean_std = (
            tf.compat.v1.get_default_session().make_callable(
                [
                    self.model.networks['default'].step_mean,
                    self.model.networks['default'].step_log_std,
                    self.model.networks['default'].step_hidden
                ],
                feed_list=[step_input_var, step_hidden_var]))

    @property
    def vectorized(self):
        """bool: Whether the policy is vectorized or not."""
        return True

    def dist_info_sym(self, obs_var, state_info_vars, name=None):
        """Build a symbolic graph of the distribution parameters.

        Args:
            obs_var (tf.Tensor): Tensor input for symbolic graph.
            state_info_vars (dict): Extra state information, e.g.
                previous action.
            name (str): Name for symbolic graph.

        Returns:
            dict[tf.Tensor]: Outputs of the symbolic graph of distribution
                parameters.

        """
        if self._state_include_action:
            prev_action_var = state_info_vars['prev_action']
            prev_action_var = tf.cast(prev_action_var, tf.float32)
            all_input_var = tf.concat(axis=2,
                                      values=[obs_var, prev_action_var])
        else:
            all_input_var = obs_var

        with tf.compat.v1.variable_scope(self._variable_scope):
            mean_var, _, log_std_var, _, _, _, _ = self.model.build(
                all_input_var,
                self.model.networks['default'].step_input,
                self.model.networks['default'].step_hidden_input,
                name=name)

        return dict(mean=mean_var, log_std=log_std_var)

    def reset(self, dones=None):
        """Reset the policy.

        Note:
            If `dones` is None, it will be by default `np.array([True])` which
            implies the policy will not be "vectorized", i.e. number of
            parallel environments for training data sampling = 1.

        Args:
            dones (numpy.ndarray): Bool array indicating terminal state(s).

        """
        if dones is None:
            dones = np.array([True])
        if self._prev_actions is None or len(dones) != len(self._prev_actions):
            self._prev_actions = np.zeros(
                (len(dones), self.action_space.flat_dim))
            self._prev_hiddens = np.zeros((len(dones), self._hidden_dims[0]))

        self._prev_actions[dones] = 0.
        self._prev_hiddens[dones] = self.model.networks[
            'default'].init_hidden.eval()

    def get_action(self, observation):
        """Get a single action from this policy for the input observation.

        Args:
            observation (numpy.ndarray): Observation from environment.

        Returns:
            tuple[numpy.ndarray, dict]: Predicted action and agent info.

                action (numpy.ndarray): Predicted action.
                agent_info (dict): Distribution obtained after observing the
                    given observation, with keys
                    * mean: (numpy.ndarray)
                    * log_std: (numpy.ndarray)
                    * prev_action: (numpy.ndarray), only present if
                        self._state_include_action is True.

        """
        actions, agent_infos = self.get_actions([observation])
        return actions[0], {k: v[0] for k, v in agent_infos.items()}

    def get_actions(self, observations):
        """Get multiple actions from this policy for the input observations.

        Args:
            observations (numpy.ndarray): Observations from environment.

        Returns:
            tuple[numpy.ndarray, dict]: Prediction actions and agent infos.

                actions (numpy.ndarray): Predicted actions.
                agent_infos (dict): Distribution obtained after observing the
                    given observation, with keys
                    * mean: (numpy.ndarray)
                    * log_std: (numpy.ndarray)
                    * prev_action: (numpy.ndarray), only present if
                        self._state_include_action is True.

        """
        # flat_obs = self.observation_space.flatten_n(observations)
        if self._state_include_action:
            assert self._prev_actions is not None
            all_input = np.concatenate([observations, self._prev_actions],
                                       axis=-1)
        else:
            all_input = observations
        means, log_stds, hidden_vec = self._f_step_mean_std(
            all_input, self._prev_hiddens)
        rnd = np.random.normal(size=means.shape)
        samples = rnd * np.exp(log_stds) + means
        # samples = self.action_space.unflatten_n(samples)
        prev_actions = self._prev_actions
        self._prev_actions = samples
        self._prev_hiddens = hidden_vec
        agent_infos = dict(mean=means, log_std=log_stds)
        if self._state_include_action:
            agent_infos['prev_action'] = np.copy(prev_actions)
        return samples, agent_infos

    @property
    def recurrent(self):
        """bool: Whether this policy is recurrent or not."""
        return True

    @property
    def distribution(self):
        """metarl.tf.distributions.DiagonalGaussian: Policy distribution."""
        return self.model.networks['default'].dist

    @property
    def state_info_specs(self):
        """list: State info specification."""
        if self._state_include_action:
            return [
                ('prev_action', (self._action_dim, )),
            ]

        return []

    def __getstate__(self):
        """See `Object.__getstate__`."""
        new_dict = super().__getstate__()
        del new_dict['_f_step_mean_std']
        return new_dict

    def __setstate__(self, state):
        """See `Object.__setstate__`."""
        super().__setstate__(state)
        self._initialize()
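A hedged usage sketch for this variant, which builds its graph during
__init__ (so a default session must already exist); the EnvSpec import path
and constructor are assumptions based on the metarl lineage:

import akro
import numpy as np
import tensorflow as tf
from metarl.envs.env_spec import EnvSpec  # assumed import path

with tf.compat.v1.Session() as sess:
    spec = EnvSpec(
        observation_space=akro.Box(low=-1.0, high=1.0, shape=(4, )),
        action_space=akro.Box(low=-1.0, high=1.0, shape=(2, )))
    policy = GaussianGRUPolicy(env_spec=spec, hidden_dims=[32])
    sess.run(tf.compat.v1.global_variables_initializer())
    policy.reset()  # zeroes prev_action and the hidden state
    action, info = policy.get_action(np.zeros(4, dtype=np.float32))
    # info has 'mean', 'log_std' and, since state_include_action=True,
    # 'prev_action'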
Example 7
class GaussianGRUPolicy(StochasticPolicy):
    """Gaussian GRU Policy.

    A policy represented by a Gaussian distribution
    which is parameterized by a Gated Recurrent Unit (GRU).

    Args:
        env_spec (metarl.envs.env_spec.EnvSpec): Environment specification.
        name (str): Model name, also the variable scope.
        hidden_dim (int): Hidden dimension for GRU cell for mean.
        hidden_nonlinearity (Callable): Activation function for intermediate
            dense layer(s). It should return a tf.Tensor. Set it to
            None to maintain a linear activation.
        hidden_w_init (Callable): Initializer function for the weight
            of intermediate dense layer(s). The function should return a
            tf.Tensor.
        hidden_b_init (Callable): Initializer function for the bias
            of intermediate dense layer(s). The function should return a
            tf.Tensor.
        recurrent_nonlinearity (Callable): Activation function for recurrent
            layers. It should return a tf.Tensor. Set it to None to
            maintain a linear activation.
        recurrent_w_init (Callable): Initializer function for the weight
            of recurrent layer(s). The function should return a
            tf.Tensor.
        output_nonlinearity (Callable): Activation function for output dense
            layer. It should return a tf.Tensor. Set it to None to
            maintain a linear activation.
        output_w_init (Callable): Initializer function for the weight
            of output dense layer(s). The function should return a
            tf.Tensor.
        output_b_init (Callable): Initializer function for the bias
            of output dense layer(s). The function should return a
            tf.Tensor.
        hidden_state_init (Callable): Initializer function for the
            initial hidden state. The function should return a tf.Tensor.
        hidden_state_init_trainable (bool): Bool for whether the initial
            hidden state is trainable.
        learn_std (bool): Whether the std is trainable.
        std_share_network (bool): Boolean for whether mean and std share
            the same network.
        init_std (float): Initial value for std.
        layer_normalization (bool): Bool for using layer normalization or not.
        state_include_action (bool): Whether the state includes action.
            If True, input dimension will be
            (observation dimension + action dimension).

    """

    def __init__(self,
                 env_spec,
                 hidden_dim=32,
                 name='GaussianGRUPolicy',
                 hidden_nonlinearity=tf.nn.tanh,
                 hidden_w_init=tf.initializers.glorot_uniform(),
                 hidden_b_init=tf.zeros_initializer(),
                 recurrent_nonlinearity=tf.nn.sigmoid,
                 recurrent_w_init=tf.initializers.glorot_uniform(),
                 output_nonlinearity=None,
                 output_w_init=tf.initializers.glorot_uniform(),
                 output_b_init=tf.zeros_initializer(),
                 hidden_state_init=tf.zeros_initializer(),
                 hidden_state_init_trainable=False,
                 learn_std=True,
                 std_share_network=False,
                 init_std=1.0,
                 layer_normalization=False,
                 state_include_action=True):
        if not isinstance(env_spec.action_space, akro.Box):
            raise ValueError('GaussianGRUPolicy only works with '
                             'akro.Box action spaces, not {}'.format(
                                 env_spec.action_space))
        super().__init__(name, env_spec)
        self._obs_dim = env_spec.observation_space.flat_dim
        self._action_dim = env_spec.action_space.flat_dim

        self._hidden_dim = hidden_dim
        self._hidden_nonlinearity = hidden_nonlinearity
        self._hidden_w_init = hidden_w_init
        self._hidden_b_init = hidden_b_init
        self._recurrent_nonlinearity = recurrent_nonlinearity
        self._recurrent_w_init = recurrent_w_init
        self._output_nonlinearity = output_nonlinearity
        self._output_w_init = output_w_init
        self._output_b_init = output_b_init
        self._hidden_state_init = hidden_state_init
        self._hidden_state_init_trainable = hidden_state_init_trainable
        self._learn_std = learn_std
        self._std_share_network = std_share_network
        self._init_std = init_std
        self._layer_normalization = layer_normalization
        self._state_include_action = state_include_action

        if state_include_action:
            self._input_dim = self._obs_dim + self._action_dim
        else:
            self._input_dim = self._obs_dim

        self._f_step_mean_std = None

        self.model = GaussianGRUModel(
            output_dim=self._action_dim,
            hidden_dim=hidden_dim,
            name='GaussianGRUModel',
            hidden_nonlinearity=hidden_nonlinearity,
            hidden_w_init=hidden_w_init,
            hidden_b_init=hidden_b_init,
            recurrent_nonlinearity=recurrent_nonlinearity,
            recurrent_w_init=recurrent_w_init,
            output_nonlinearity=output_nonlinearity,
            output_w_init=output_w_init,
            output_b_init=output_b_init,
            hidden_state_init=hidden_state_init,
            hidden_state_init_trainable=hidden_state_init_trainable,
            layer_normalization=layer_normalization,
            learn_std=learn_std,
            std_share_network=std_share_network,
            init_std=init_std)

        self._prev_actions = None
        self._prev_hiddens = None

    def build(self, state_input, name=None):
        """Build model.

        Args:
          state_input (tf.Tensor): State input.
          name (str): Name of the model, which is also the name scope.

        """
        with tf.compat.v1.variable_scope(self.name) as vs:
            self._variable_scope = vs
            step_input_var = tf.compat.v1.placeholder(shape=(None,
                                                             self._input_dim),
                                                      name='step_input',
                                                      dtype=tf.float32)
            step_hidden_var = tf.compat.v1.placeholder(
                shape=(None, self._hidden_dim),
                name='step_hidden_input',
                dtype=tf.float32)
            self.model.build(state_input,
                             step_input_var,
                             step_hidden_var,
                             name=name)

        self._f_step_mean_std = (
            tf.compat.v1.get_default_session().make_callable(
                [
                    self.model.networks['default'].step_mean,
                    self.model.networks['default'].step_log_std,
                    self.model.networks['default'].step_hidden
                ],
                feed_list=[step_input_var, step_hidden_var]))

    @property
    def vectorized(self):
        """Vectorized or not.

        Returns:
            Bool: True if primitive supports vectorized operations.

        """
        return True

    def reset(self, do_resets=None):
        """Reset the policy.

        Note:
            If `do_resets` is None, it will be by default `np.array([True])`
            which implies the policy will not be "vectorized", i.e. number of
            parallel environments for training data sampling = 1.

        Args:
            do_resets (numpy.ndarray): Bool array indicating terminal
                state(s).

        """
        if do_resets is None:
            do_resets = np.array([True])
        if self._prev_actions is None or len(do_resets) != len(
                self._prev_actions):
            self._prev_actions = np.zeros(
                (len(do_resets), self.action_space.flat_dim))
            self._prev_hiddens = np.zeros((len(do_resets), self._hidden_dim))

        self._prev_actions[do_resets] = 0.
        self._prev_hiddens[do_resets] = self.model.networks[
            'default'].init_hidden.eval()

    def get_action(self, observation):
        """Get single action from this policy for the input observation.

        Args:
            observation (numpy.ndarray): Observation from environment.

        Returns:
            numpy.ndarray: Actions
            dict: Predicted action and agent information.

        Note:
            It returns an action and a dict, with keys
            - mean (numpy.ndarray): Mean of the distribution.
            - log_std (numpy.ndarray): Log standard deviation of the
                distribution.
            - prev_action (numpy.ndarray): Previous action, only present if
                self._state_include_action is True.

        """
        actions, agent_infos = self.get_actions([observation])
        return actions[0], {k: v[0] for k, v in agent_infos.items()}

    def get_actions(self, observations):
        """Get multiple actions from this policy for the input observations.

        Args:
            observations (numpy.ndarray): Observations from environment.

        Returns:
            numpy.ndarray: Actions
            dict: Predicted action and agent information.

        Note:
            It returns an action and a dict, with keys
            - mean (numpy.ndarray): Means of the distribution.
            - log_std (numpy.ndarray): Log standard deviations of the
                distribution.
            - prev_action (numpy.ndarray): Previous action, only present if
                self._state_include_action is True.

        """
        if self._state_include_action:
            assert self._prev_actions is not None
            all_input = np.concatenate([observations, self._prev_actions],
                                       axis=-1)
        else:
            all_input = observations
        means, log_stds, hidden_vec = self._f_step_mean_std(
            all_input, self._prev_hiddens)
        rnd = np.random.normal(size=means.shape)
        samples = rnd * np.exp(log_stds) + means
        samples = self.action_space.unflatten_n(samples)
        prev_actions = self._prev_actions
        self._prev_actions = samples
        self._prev_hiddens = hidden_vec
        agent_infos = dict(mean=means, log_std=log_stds)
        if self._state_include_action:
            agent_infos['prev_action'] = np.copy(prev_actions)
        return samples, agent_infos

    @property
    def distribution(self):
        """Policy distribution.

        Returns:
            tfp.distributions.MultivariateNormalDiag: Policy distribution.

        """
        return self.model.networks['default'].dist

    @property
    def state_info_specs(self):
        """State info specifcation.

        Returns:
            List[str]: keys and shapes for the information related to the
                policy's state when taking an action.

        """
        if self._state_include_action:
            return [
                ('prev_action', (self._action_dim, )),
            ]

        return []

    def clone(self, name):
        """Return a clone of the policy.

        It only copies the configuration of the primitive,
        not the parameters.

        Args:
            name (str): Name of the newly created policy. It has to be
                different from source policy if cloned under the same
                computational graph.

        Returns:
            metarl.tf.policies.GaussianGRUPolicy: Newly cloned policy.

        """
        return self.__class__(
            name=name,
            env_spec=self._env_spec,
            hidden_dim=self._hidden_dim,
            hidden_nonlinearity=self._hidden_nonlinearity,
            hidden_w_init=self._hidden_w_init,
            hidden_b_init=self._hidden_b_init,
            recurrent_nonlinearity=self._recurrent_nonlinearity,
            recurrent_w_init=self._recurrent_w_init,
            output_nonlinearity=self._output_nonlinearity,
            output_w_init=self._output_w_init,
            output_b_init=self._output_b_init,
            hidden_state_init=self._hidden_state_init,
            hidden_state_init_trainable=self._hidden_state_init_trainable,
            learn_std=self._learn_std,
            std_share_network=self._std_share_network,
            init_std=self._init_std,
            layer_normalization=self._layer_normalization,
            state_include_action=self._state_include_action)

    def __getstate__(self):
        """Object.__getstate__.

        Returns:
            dict: the state to be pickled for the instance.

        """
        new_dict = super().__getstate__()
        del new_dict['_f_step_mean_std']
        return new_dict
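
Unlike the previous variant, this policy defers graph construction to an
explicit build call, so the wiring order differs; a sketch under the same
assumed imports:

import akro
import numpy as np
import tensorflow as tf
from metarl.envs.env_spec import EnvSpec  # assumed import path

with tf.compat.v1.Session() as sess:
    spec = EnvSpec(
        observation_space=akro.Box(low=-1.0, high=1.0, shape=(4, )),
        action_space=akro.Box(low=-1.0, high=1.0, shape=(2, )))
    policy = GaussianGRUPolicy(env_spec=spec, hidden_dim=32)
    # state_include_action=True makes the input obs + action: 4 + 2 = 6
    state_input = tf.compat.v1.placeholder(tf.float32,
                                           shape=(None, None, 6))
    policy.build(state_input)  # wires the model and the step callable
    sess.run(tf.compat.v1.global_variables_initializer())
    policy.reset()
    action, info = policy.get_action(np.zeros(4, dtype=np.float32))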