Example #1

# The snippets below are methods of TFPolicy-style classes and assume the
# following module context (imports reconstructed for readability):
from collections import OrderedDict

import tensorflow as tf
import tensorflow_probability as tfp

from tf_agents.distributions import reparameterized_sampling
from tf_agents.trajectories import policy_step
    def _action(self, time_step, policy_state, seed):
        """Implementation of `action`.

    Args:
      time_step: A `TimeStep` tuple corresponding to `time_step_spec()`.
      policy_state: A Tensor, or a nested dict, list or tuple of Tensors
        representing the previous policy_state.
      seed: Seed to use if action performs sampling (optional).

    Returns:
      A `PolicyStep` named tuple containing:
        `action`: An action Tensor matching the `action_spec`.
        `state`: A policy state tensor to be fed into the next call to action.
        `info`: Optional side information such as action log probabilities.
    """
        seed_stream = tfp.util.SeedStream(seed=seed,
                                          salt='tf_agents_tf_policy')
        distribution_step = self._distribution(time_step, policy_state)
        actions = tf.nest.map_structure(
            lambda d: reparameterized_sampling.sample(d, seed=seed_stream()),
            distribution_step.action)
        info = distribution_step.info
        if self.emit_log_probability:
            try:
                log_probability = tf.nest.map_structure(
                    lambda a, d: d.log_prob(a), actions,
                    distribution_step.action)
                info = policy_step.set_log_probability(info, log_probability)
            except Exception as e:
                raise TypeError(
                    '%s does not support emitting log-probabilities.' %
                    type(self).__name__) from e

        return distribution_step._replace(action=actions, info=info)
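For context, here is a minimal sketch of the seed plumbing used above. It assumes only tensorflow_probability is installed; the seed value is illustrative:

import tensorflow_probability as tfp

# SeedStream turns one user-provided seed plus a salt into a deterministic
# sequence of per-call seeds, so each distribution in the nest gets a
# distinct but reproducible sample seed.
stream = tfp.util.SeedStream(seed=123, salt='tf_agents_tf_policy')
first, second = stream(), stream()
assert first != second

# Re-creating the stream with the same seed and salt replays the sequence.
replay = tfp.util.SeedStream(seed=123, salt='tf_agents_tf_policy')
assert replay() == first
assert replay() == second

A separate, option-based variant of `_action` follows.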
    def _action(self, time_step, policy_state, seed=1):
        """Implementation of `action`.

        Args:
            time_step: A `TimeStep` tuple corresponding to `time_step_spec()`.
            policy_state: A Tensor, or a nested dict, list or tuple of Tensors
                representing the previous policy_state.
            seed: Seed to use if action performs sampling (optional).

        Returns:
            A `PolicyStep` named tuple containing:
                `action`: An action Tensor matching the `action_spec()`.
                `state`: A policy state tensor to be fed into the next call.
                `info`: Optional information such as action log probabilities.
        """
        if self._t % self._steps_per_option == 0:
            # Start of a new option: resample the latent and stash it,
            # together with its distribution step, in the policy state so
            # later steps can reuse it.
            seed_stream = tfp.util.SeedStream(seed=seed, salt='ppo_policy')
            distribution_step = self._latent_distribution(
                time_step, policy_state)
            latent_actions = tf.nest.map_structure(
                lambda d: reparameterized_sampling.sample(d,
                                                          seed=seed_stream()),
                distribution_step.action)
            policy_state = (distribution_step, latent_actions)
        else:
            # Mid-option: reuse the latent sampled at the start of the option.
            (distribution_step, latent_actions) = policy_state
        self._t += 1
        action_distribution, _ = self._generator_network(
            OrderedDict({
                "observation": time_step.observation,
                "z": latent_actions
            }), time_step.step_type, policy_state)
        if self.emit_log_probability:
            raise NotImplementedError(
                'emit_log_probability is not supported by this policy.')
        info = distribution_step.info
        actions = tf.nest.map_structure(lambda d: d.sample(),
                                        action_distribution)
        return distribution_step._replace(action=actions,
                                          info=info,
                                          state=policy_state)
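The gating on `self._t % self._steps_per_option` above is the core of the option mechanism: the latent is resampled only at the start of each option and reused in between, which is why the sampled latent must be threaded through `policy_state`. A toy sketch of that schedule, using only TensorFlow and an illustrative `steps_per_option=3`:

import tensorflow as tf

steps_per_option = 3  # hypothetical value; the real one comes from the policy
latent = None
for t in range(7):
    if t % steps_per_option == 0:
        # Start of an option: draw a fresh latent (stands in for sampling
        # from the latent distribution above).
        latent = tf.random.normal([2])
        print(f'step {t}: resampled latent {latent.numpy()}')
    else:
        # Mid-option: reuse the stored latent, as `policy_state` does above.
        print(f'step {t}: reusing latent {latent.numpy()}')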
    def _distribution(self, time_step, policy_state, seed=1):
        seed_stream = tfp.util.SeedStream(seed=seed, salt='ppo_policy')
        distribution_step = self._latent_distribution(time_step, policy_state)
        latent_actions = tf.nest.map_structure(
            lambda d: reparameterized_sampling.sample(d, seed=seed_stream()),
            distribution_step.action)
        action_distribution, _ = self._generator_network(
            OrderedDict({
                "observation": time_step.observation,
                "z": latent_actions
            }), time_step.step_type, policy_state)

        def _to_distribution(action_or_distribution):
            if isinstance(action_or_distribution, tf.Tensor):
                # This is an action tensor, so wrap it in a deterministic
                # distribution.
                return tfp.distributions.Deterministic(
                    loc=action_or_distribution)
            return action_or_distribution

        distributions = tf.nest.map_structure(_to_distribution,
                                              action_distribution)
        return policy_step.PolicyStep(distributions, policy_state)
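A minimal, self-contained sketch of what `_to_distribution` does to a mixed nest, assuming only TensorFlow and TensorFlow Probability; the nest contents are illustrative:

import tensorflow as tf
import tensorflow_probability as tfp

def _to_distribution(action_or_distribution):
    # Wrap raw action tensors in a Deterministic distribution; pass real
    # distributions through unchanged.
    if isinstance(action_or_distribution, tf.Tensor):
        return tfp.distributions.Deterministic(loc=action_or_distribution)
    return action_or_distribution

nest = {
    'continuous': tfp.distributions.Normal(loc=0., scale=1.),
    'raw': tf.constant([0.5, 1.5]),  # a plain action tensor
}
wrapped = tf.nest.map_structure(_to_distribution, nest)
print(wrapped['raw'].sample())         # always [0.5, 1.5]
print(wrapped['continuous'].sample())  # a fresh Normal draw

Wrapping plain tensors this way lets callers treat every leaf of the returned `PolicyStep` uniformly, e.g. calling `.sample()` or `.log_prob()` without checking types.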