Example 1
    def testActionProbabilities(self, observation_shape, batch_size, weights,
                                inverse_temperature, seed):
        observation_spec = tensor_spec.TensorSpec(shape=observation_shape,
                                                  dtype=tf.float32,
                                                  name='observation_spec')
        time_step_spec = time_step.time_step_spec(observation_spec)
        action_spec = tensor_spec.BoundedTensorSpec(
            shape=(),
            dtype=tf.int32,
            minimum=0,
            maximum=tf.compat.dimension_value(weights.shape[0]) - 1,
            name='action')
        policy = categorical_policy.CategoricalPolicy(weights, time_step_spec,
                                                      action_spec,
                                                      inverse_temperature)
        observation_step = _get_dummy_observation_step(observation_shape,
                                                       batch_size)
        action_time_step = policy.action(observation_step, seed=seed)

        logits = inverse_temperature * weights
        z = tf.reduce_logsumexp(logits)
        expected_logprob = logits - z
        expected_action_prob = tf.exp(
            tf.gather(expected_logprob, action_time_step.action))
        actual_action_prob = tf.exp(
            policy_step.get_log_probability(action_time_step.info))
        expected_action_prob_val, actual_action_prob_val = self.evaluate(
            [expected_action_prob, actual_action_prob])
        self.assertAllClose(expected_action_prob_val, actual_action_prob_val)
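
For reference, the probabilities this test checks are just a softmax over the scaled weights. A minimal standalone sketch of the same computation, assuming TensorFlow 2.x and illustrative values for `weights` and `inverse_temperature`:

import tensorflow as tf

weights = tf.constant([1.0, 2.0, 3.0])        # illustrative arm weights
inverse_temperature = 0.5

# p_i = exp(t * w_i) / sum_j exp(t * w_j), computed in log space for stability.
logits = inverse_temperature * weights
probs = tf.exp(logits - tf.reduce_logsumexp(logits))
print(probs.numpy())                          # entries sum to 1.0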
Example 2
  def distribution(self, time_step, policy_state=()):
    """Generates the distribution over next actions given the time_step.

    Args:
      time_step: A `TimeStep` tuple corresponding to `time_step_spec()`.
      policy_state: A Tensor, or a nested dict, list or tuple of Tensors
        representing the previous policy_state.

    Returns:
      A `PolicyStep` named tuple containing:

        `action`: A tf.distribution capturing the distribution of next actions.
        `state`: A policy state tensor for the next call to distribution.
        `info`: Optional side information such as action log probabilities.
    """
    tf.nest.assert_same_structure(time_step, self._time_step_spec)
    tf.nest.assert_same_structure(policy_state, self._policy_state_spec)
    if self._automatic_state_reset:
      policy_state = self._maybe_reset_state(time_step, policy_state)
    step = self._distribution(time_step=time_step, policy_state=policy_state)
    if self.emit_log_probability:
      # This is set only for compatibility with the info_spec in the constructor.
      info = policy_step.set_log_probability(
          step.info,
          tf.nest.map_structure(
              lambda _: tf.constant(0., dtype=tf.float32),
              policy_step.get_log_probability(self._info_spec)))
      step = step._replace(info=info)
    tf.nest.assert_same_structure(step, self._policy_step_spec)
    return step
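
A minimal usage sketch of this method, assuming the `CategoricalPolicy`, the `time_step` module alias, and `observation_shape` from Example 1; the batch size of 2 is illustrative:

# Build a batched TimeStep that matches time_step_spec, then query the policy.
observations = tf.zeros([2] + list(observation_shape), dtype=tf.float32)
first_step = time_step.restart(observations, batch_size=2)

dist_step = policy.distribution(first_step)
action_distribution = dist_step.action   # e.g. a Categorical over the arms
sampled_actions = action_distribution.sample()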
Example 3
    def _train(self, experience, weights=None):
        """Updates the policy based on the data in `experience`.

        Note that `experience` should only contain data points that this agent has
        not previously seen. If `experience` comes from a replay buffer, this buffer
        should be cleared between each call to `train`.

        Args:
          experience: A batch of experience data in the form of a `Trajectory`.
          weights: Unused.

        Returns:
          A `LossInfo` containing the loss *before* the training step is taken.
            Note that the loss does not depend on policy state and comes directly
            from the experience (and is therefore not differentiable).

            In most cases, if `weights` is provided, the entries of this tuple will
            have been calculated with the weights.  Note that each Agent chooses
            its own method of applying weights.
        """
        del weights  # unused
        reward = experience.reward
        log_prob = policy_step.get_log_probability(experience.policy_info)
        action = experience.action
        update_value = exp3_update_value(reward, log_prob)
        weight_update = selective_sum(values=update_value,
                                      partitions=action,
                                      num_partitions=self.num_actions)
        tf.compat.v1.assign_add(self._weights, weight_update)

        batch_size = tf.cast(tf.size(reward), dtype=tf.int64)
        self._train_step_counter.assign_add(batch_size)

        return tf_agent.LossInfo(loss=-tf.reduce_sum(experience.reward),
                                 extra=())
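
The per-arm bookkeeping here, where `selective_sum` scatters each update value into the bucket of the action that produced it, can be illustrated with `tf.math.unsorted_segment_sum`. This is a hedged sketch of the idea, not the library's implementation:

import tensorflow as tf

num_actions = 3
actions = tf.constant([0, 2, 2, 1])                  # arm pulled at each step
update_values = tf.constant([0.5, 1.0, 0.25, 0.1])   # per-step update values

# Sum each update into the slot of the arm that generated it.
weight_update = tf.math.unsorted_segment_sum(update_values, actions, num_actions)
print(weight_update.numpy())                         # [0.5, 0.1, 1.25]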
Example 4
    def distribution(
        self, time_step: ts.TimeStep, policy_state: types.NestedTensor = ()
    ) -> policy_step.PolicyStep:
        """Generates the distribution over next actions given the time_step.

        Args:
          time_step: A `TimeStep` tuple corresponding to `time_step_spec()`.
          policy_state: A Tensor, or a nested dict, list or tuple of Tensors
            representing the previous policy_state.

        Returns:
          A `PolicyStep` named tuple containing:

            `action`: A tf.distribution capturing the distribution of next actions.
            `state`: A policy state tensor for the next call to distribution.
            `info`: Optional side information such as action log probabilities.

        Raises:
          ValueError or TypeError: If `validate_args is True` and inputs or
            outputs do not match `time_step_spec`, `policy_state_spec`,
            or `policy_step_spec`.
        """
        if self._validate_args:
            time_step = nest_utils.prune_extra_keys(self._time_step_spec,
                                                    time_step)
            policy_state = nest_utils.prune_extra_keys(self._policy_state_spec,
                                                       policy_state)
            nest_utils.assert_same_structure(
                time_step,
                self._time_step_spec,
                message='time_step and time_step_spec structures do not match')
            nest_utils.assert_same_structure(
                policy_state,
                self._policy_state_spec,
                message=('policy_state and policy_state_spec structures '
                         'do not match'))
        if self._automatic_state_reset:
            policy_state = self._maybe_reset_state(time_step, policy_state)
        step = self._distribution(time_step=time_step,
                                  policy_state=policy_state)
        if self.emit_log_probability:
            # This is set only for compatibility with the info_spec in the constructor.
            info = policy_step.set_log_probability(
                step.info,
                tf.nest.map_structure(
                    lambda _: tf.constant(0., dtype=tf.float32),
                    policy_step.get_log_probability(self._info_spec)))
            step = step._replace(info=info)
        if self._validate_args:
            nest_utils.assert_same_structure(
                step,
                self._policy_step_spec,
                message=('distribution output and policy_step_spec structures '
                         'do not match'))
        return step
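
The `validate_args` path rejects inputs whose nest structure differs from the specs, and `prune_extra_keys` exists to drop keys the spec does not know about before that check runs. A hedged illustration using a made-up spec and value:

import tensorflow as tf
from tf_agents.utils import nest_utils

spec = {'observation': tf.TensorSpec([2], tf.float32)}
value = {'observation': tf.zeros([2]), 'debug': tf.zeros([1])}  # carries an extra key

try:
    tf.nest.assert_same_structure(value, spec)
except (ValueError, TypeError) as err:
    print('rejected before pruning:', type(err).__name__)

pruned = nest_utils.prune_extra_keys(spec, value)  # drop the unknown 'debug' key
tf.nest.assert_same_structure(pruned, spec)        # now passes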