  def _distribution(self, time_step, policy_state):
    """Generates the distribution over next actions given the time_step.

    Args:
      time_step: A `TimeStep` tuple corresponding to `time_step_spec()`.
      policy_state: A Tensor, or a nested dict, list or tuple of
        Tensors representing the previous policy_state.

    Returns:
      A tfp.distributions.Categorical capturing the distribution of next
        actions.
      A policy_state Tensor, or a nested dict, list or tuple of Tensors,
        representing the new policy state.
    """
    q_logits, policy_state = self._q_network(
        time_step.observation, time_step.step_type, policy_state)
    q_logits.shape.assert_has_rank(3)
    q_values = common.convert_q_logits_to_values(q_logits, self._support)

    logits = q_values
    mask_split_fn = self._q_network.mask_split_fn

    if mask_split_fn:
      _, mask = mask_split_fn(time_step.observation)
      # Overwrite the logits for invalid actions to -inf.
      neg_inf = tf.constant(-np.inf, dtype=logits.dtype)
      logits = tf.compat.v2.where(tf.cast(mask, tf.bool), logits, neg_inf)

    dist = tfp.distributions.Categorical(
        logits=logits, dtype=self.action_spec.dtype)
    return policy_step.PolicyStep(dist, policy_state)
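
Each of these snippets relies on common.convert_q_logits_to_values to collapse the per-atom logits into scalar Q-values before building the action distribution. Below is a minimal sketch of what that helper is assumed to compute (softmax over the atom axis, then an expectation over the support); treat it as an illustration, not the TF-Agents implementation.

import tensorflow as tf

def convert_q_logits_to_values_sketch(q_logits, support):
  """Collapses distributional logits to expected Q-values.

  Args:
    q_logits: Tensor of shape [batch_size, num_actions, num_atoms].
    support: Tensor of shape [num_atoms] holding the atom values.

  Returns:
    Tensor of shape [batch_size, num_actions] with expected Q-values.
  """
  probabilities = tf.nn.softmax(q_logits, axis=-1)
  return tf.reduce_sum(support * probabilities, axis=-1)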
Example #2
  def testConvertQLogitsToValuesBatch(self):
    logits = tf.constant([[[1., 20., 1.], [1., 1., 20.]],
                          [[20., 1., 1.], [1., 20., 20.]]])
    support = tf.constant([10., 20., 30.])
    values = common.convert_q_logits_to_values(logits, support)
    values_ = self.evaluate(values)
    self.assertAllClose(values_, [[20.0, 30.0], [10., 25.]], 0.001)
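
The expected values in the assertion follow directly from the softmax weighting: a single logit of 20 against logits of 1 puts essentially all probability mass on one atom, while two equal logits of 20 split the mass evenly between two atoms. A standalone NumPy check of that arithmetic, independent of the test harness:

import numpy as np

logits = np.array([[[1., 20., 1.], [1., 1., 20.]],
                   [[20., 1., 1.], [1., 20., 20.]]])
support = np.array([10., 20., 30.])

probs = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)
values = (probs * support).sum(axis=-1)
print(values)  # approx. [[20. 30.] [10. 25.]]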
Example #3
  def _distribution(self, time_step, policy_state):
    """Generates the distribution over next actions given the time_step.

    Args:
      time_step: A `TimeStep` tuple corresponding to `time_step_spec()`.
      policy_state: A Tensor, or a nested dict, list or tuple of
        Tensors representing the previous policy_state.

    Returns:
      A tfp.distributions.Categorical capturing the distribution of next
        actions.
      A policy_state Tensor, or a nested dict, list or tuple of Tensors,
        representing the new policy state.
    """
    network_observation = time_step.observation
    observation_and_action_constraint_splitter = (
        self.observation_and_action_constraint_splitter)

    if observation_and_action_constraint_splitter is not None:
      network_observation, mask = (
          observation_and_action_constraint_splitter(network_observation))

    q_logits, policy_state = self._q_network(
        network_observation, step_type=time_step.step_type,
        network_state=policy_state)
    q_logits.shape.assert_has_rank(3)
    q_values = common.convert_q_logits_to_values(q_logits, self._support)

    logits = q_values

    if observation_and_action_constraint_splitter is not None:
      # Overwrite the logits for invalid actions to -inf.
      neg_inf = tf.constant(-np.inf, dtype=logits.dtype)
      logits = tf.compat.v2.where(tf.cast(mask, tf.bool), logits, neg_inf)

    dist = tfp.distributions.Categorical(
        logits=logits, dtype=self.action_spec.dtype)  # pytype: disable=attribute-error
    return policy_step.PolicyStep(dist, policy_state)
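
This variant assumes the observation bundles the network input together with an action mask, and that the policy was configured with an observation_and_action_constraint_splitter to pull the two apart. A minimal, hypothetical splitter for an observation packed as a dict (the key names here are illustrative, not part of the snippet above):

def observation_and_action_constraint_splitter(observation):
  # Returns (network_observation, mask). Mask entries of 1 mark valid
  # actions; entries of 0 mark actions whose logits are set to -inf above.
  return observation['observation'], observation['valid_actions']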
Example #4
  def _distribution(self, time_step, policy_state):
    """Generates the distribution over next actions given the time_step.

    Args:
      time_step: A `TimeStep` tuple corresponding to `time_step_spec()`.
      policy_state: A Tensor, or a nested dict, list or tuple of
        Tensors representing the previous policy_state.

    Returns:
      A tfp.distributions.Categorical capturing the distribution of next
        actions.
      A policy_state Tensor, or a nested dict, list or tuple of Tensors,
        representing the new policy state.
    """
    q_logits, policy_state = self._q_network(
        time_step.observation, time_step.step_type, policy_state)
    q_logits.shape.assert_has_rank(3)
    q_values = common.convert_q_logits_to_values(q_logits, self._support)
    return policy_step.PolicyStep(
        tfp.distributions.Categorical(
            logits=q_values, dtype=self.action_spec.dtype),
        policy_state)
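
Stripped of the policy plumbing, all of these snippets do the same thing: collapse per-atom logits to expected Q-values and wrap them in a Categorical over actions. A self-contained sketch of that flow with hand-built logits in place of a Q-network (shapes and names chosen for illustration):

import tensorflow as tf
import tensorflow_probability as tfp

batch_size, num_actions, num_atoms = 2, 3, 5
support = tf.linspace(-10.0, 10.0, num_atoms)              # atom values
q_logits = tf.random.normal([batch_size, num_actions, num_atoms])

# Expected Q-value per action: softmax over atoms, weighted by the support.
q_values = tf.reduce_sum(tf.nn.softmax(q_logits, axis=-1) * support, axis=-1)

dist = tfp.distributions.Categorical(logits=q_values, dtype=tf.int32)
print(dist.mode().numpy())    # greedy actions, shape [batch_size]
print(dist.sample().numpy())  # sampled actions, shape [batch_size]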