示例#1
0
def categorical_kl_divergence(p_logits: Array,
                              q_logits: Array,
                              temperature: float = 1.) -> Array:
    """KL divergence between two categorical distributions given by logits.

    Deprecated shim: delegates to `distrax.Softmax.kl_divergence`.

    Args:
      p_logits: unnormalized logits for the first distribution.
      q_logits: unnormalized logits for the second distribution.
      temperature: softmax temperature applied to both distributions (default 1).

    Returns:
      The KL divergence KL(p || q) between the two softmax distributions.
    """
    warnings.warn(
        "Rlax categorical_kl_divergence will be deprecated. "
        "Please use distrax.Softmax.kl_divergence instead.",
        PendingDeprecationWarning,
        stacklevel=2)
    p_dist = distrax.Softmax(p_logits, temperature)
    q_dist = distrax.Softmax(q_logits, temperature)
    return p_dist.kl_divergence(q_dist)
示例#2
0
 def entropy_fn(logits: Array):
     """Softmax entropy, clipped at `entropy_clip` times log(num_actions)."""
     raw_entropy = distrax.Softmax(logits, temperature).entropy()
     # log(n) is the entropy of a uniform distribution over n actions.
     entropy_cap = entropy_clip * jnp.log(logits.shape[-1])
     return jnp.minimum(raw_entropy, entropy_cap)
示例#3
0
 def logprob_fn(sample: Array, logits: Array, action_spec=None):
     """Log-probability of `sample` under the softmax distribution on `logits`."""
     del action_spec  # Unused; kept so all policy fns share one signature.
     dist = distrax.Softmax(logits, temperature)
     return dist.log_prob(sample)
示例#4
0
 def probs_fn(logits: Array, action_spec=None):
     """Probabilities of the softmax distribution parameterized by `logits`."""
     del action_spec  # Unused; kept so all policy fns share one signature.
     dist = distrax.Softmax(logits, temperature)
     return dist.probs
示例#5
0
 def sample_fn(key: Array, logits: Array, action_spec=None):
     """Draw one sample from the softmax distribution using PRNG `key`."""
     del action_spec  # Unused; kept so all policy fns share one signature.
     dist = distrax.Softmax(logits, temperature)
     return dist.sample(seed=key)
示例#6
0
 def entropy_fn(logits: Array):
     """Entropy of the softmax distribution parameterized by `logits`."""
     dist = distrax.Softmax(logits, temperature)
     return dist.entropy()
示例#7
0
 def logprob_fn(sample: Array, logits: Array):
     """Log-probability of `sample` under the softmax distribution on `logits`."""
     dist = distrax.Softmax(logits, temperature)
     return dist.log_prob(sample)
示例#8
0
 def probs_fn(logits: Array):
     """Probabilities of the softmax distribution parameterized by `logits`."""
     dist = distrax.Softmax(logits, temperature)
     return dist.probs
示例#9
0
 def sample_fn(key: Array, logits: Array):
     """Draw one sample from the softmax distribution using PRNG `key`."""
     dist = distrax.Softmax(logits, temperature)
     return dist.sample(seed=key)
示例#10
0
 def entropy_fn(logits: Array):
     """Entropy of the softmax probs after epsilon-mixing with a uniform dist."""
     softmax_probs = distrax.Softmax(logits=logits, temperature=temperature).probs
     mixed_probs = _mix_with_uniform(softmax_probs, epsilon)
     return distrax.Categorical(probs=mixed_probs).entropy()
示例#11
0
 def log_prob_fn(sample: Array, logits: Array):
     """Log-prob of `sample` under the epsilon-mixed softmax distribution."""
     softmax_probs = distrax.Softmax(logits=logits, temperature=temperature).probs
     mixed_probs = _mix_with_uniform(softmax_probs, epsilon)
     return distrax.Categorical(probs=mixed_probs).log_prob(sample)
示例#12
0
 def sample_fn(key: Array, logits: Array):
     """Sample from the epsilon-mixed softmax distribution using PRNG `key`."""
     softmax_probs = distrax.Softmax(logits=logits, temperature=temperature).probs
     mixed_probs = _mix_with_uniform(softmax_probs, epsilon)
     return distrax.Categorical(probs=mixed_probs).sample(seed=key)