def categorical_kl_divergence(p_logits: Array, q_logits: Array, temperature: float = 1.) -> Array:
  """Computes KL(p || q) between two categorical distributions given by logits.

  Deprecated wrapper: forwards to `distrax.Softmax.kl_divergence`.

  Args:
    p_logits: unnormalized logits of the first distribution.
    q_logits: unnormalized logits of the second distribution.
    temperature: softmax temperature applied to both logit sets; defaults to 1.

  Returns:
    The KL divergence between the two softmax distributions.
  """
  warnings.warn(
      "Rlax categorical_kl_divergence will be deprecated. "
      "Please use distrax.Softmax.kl_divergence instead.",
      PendingDeprecationWarning, stacklevel=2)
  p_dist = distrax.Softmax(p_logits, temperature)
  q_dist = distrax.Softmax(q_logits, temperature)
  return p_dist.kl_divergence(q_dist)
def entropy_fn(logits: Array):
  """Softmax entropy, capped at `entropy_clip` times the maximum log(num_actions)."""
  raw_entropy = distrax.Softmax(logits, temperature).entropy()
  entropy_cap = entropy_clip * jnp.log(logits.shape[-1])
  return jnp.minimum(raw_entropy, entropy_cap)
def logprob_fn(sample: Array, logits: Array, action_spec=None):
  """Log-probability of `sample` under a softmax over `logits`."""
  del action_spec  # Unused; accepted only for interface compatibility.
  dist = distrax.Softmax(logits, temperature)
  return dist.log_prob(sample)
def probs_fn(logits: Array, action_spec=None):
  """Probabilities of the softmax distribution over `logits`."""
  del action_spec  # Unused; accepted only for interface compatibility.
  dist = distrax.Softmax(logits, temperature)
  return dist.probs
def sample_fn(key: Array, logits: Array, action_spec=None):
  """Draws a sample from the softmax distribution over `logits` using `key`."""
  del action_spec  # Unused; accepted only for interface compatibility.
  dist = distrax.Softmax(logits, temperature)
  return dist.sample(seed=key)
def entropy_fn(logits: Array):
  """Entropy of the softmax distribution over `logits`."""
  dist = distrax.Softmax(logits, temperature)
  return dist.entropy()
def logprob_fn(sample: Array, logits: Array):
  """Log-probability of `sample` under a softmax over `logits`."""
  dist = distrax.Softmax(logits, temperature)
  return dist.log_prob(sample)
def probs_fn(logits: Array):
  """Probabilities of the softmax distribution over `logits`."""
  dist = distrax.Softmax(logits, temperature)
  return dist.probs
def sample_fn(key: Array, logits: Array):
  """Draws a sample from the softmax distribution over `logits` using `key`."""
  dist = distrax.Softmax(logits, temperature)
  return dist.sample(seed=key)
def entropy_fn(logits: Array):
  """Entropy of the softmax over `logits`, mixed with a uniform by `epsilon`."""
  softmax_probs = distrax.Softmax(logits=logits, temperature=temperature).probs
  mixed_probs = _mix_with_uniform(softmax_probs, epsilon)
  return distrax.Categorical(probs=mixed_probs).entropy()
def log_prob_fn(sample: Array, logits: Array):
  """Log-probability of `sample` under the epsilon-mixed softmax distribution."""
  softmax_probs = distrax.Softmax(logits=logits, temperature=temperature).probs
  mixed_probs = _mix_with_uniform(softmax_probs, epsilon)
  return distrax.Categorical(probs=mixed_probs).log_prob(sample)
def sample_fn(key: Array, logits: Array):
  """Draws a sample from the epsilon-mixed softmax distribution using `key`."""
  softmax_probs = distrax.Softmax(logits=logits, temperature=temperature).probs
  mixed_probs = _mix_with_uniform(softmax_probs, epsilon)
  return distrax.Categorical(probs=mixed_probs).sample(seed=key)