def log_prob(self, inputs, point): point = point.reshape(inputs.shape[:-1] + (-1, )) return ( # L2 term. -jnp.sum((point - inputs)**2, axis=-1) / (2 * self._std**2) - # Normalizing constant. ((jnp.log(self._std) + jnp.log(jnp.sqrt(2 * jnp.pi))) * jnp.prod(self._shape)))
def log_gaussian_diag_pdf(x, mu, diag_sigma):  # pylint: disable=invalid-name
  """Compute log N(x | mu, diag(diag_sigma)).

  `diag_sigma` holds the diagonal of the covariance matrix (per-dimension
  variances). Its log-determinant is therefore `sum(log(diag_sigma))`.

  Args:
    x: Points to evaluate; the last axis is the event axis.
    mu: Mean; the last axis is the event axis.
    diag_sigma: Diagonal of the covariance matrix, broadcastable against
      `x - mu`.

  Returns:
    Log-density reduced over the last axis.
  """
  diff = x - mu
  a = mu.shape[-1] * jnp.log(2 * jnp.pi)
  b = jnp.sum(jnp.log(diag_sigma), axis=-1)
  # Quadratic form (x-mu)^T diag(sigma)^-1 (x-mu). The division must apply to
  # the difference: (x - mu) / sigma, not x - (mu / sigma).
  c = jnp.sum(diff * (diff / diag_sigma), axis=-1)
  return -0.5 * (a + b + c)
def log_gaussian_pdf(x, mu, sigma):  # pylint: disable=invalid-name
  """Compute log N(x | mu, sigma) for a full covariance matrix `sigma`.

  Args:
    x: Points to evaluate; the last axis is the event axis.
    mu: Mean; the last axis is the event axis.
    sigma: Covariance matrix (square in its last two axes).

  Returns:
    Log-density with the event axis reduced away.
  """
  diff = x - mu
  log_two_pi_term = mu.shape[-1] * jnp.log(2 * jnp.pi)
  # slogdet returns (sign, log|det|); only the log-magnitude is needed.
  _, log_det = jnp.linalg.slogdet(sigma)
  solved = jnp.linalg.solve(sigma, diff)
  # Row-vector @ column-vector gives a (..., 1, 1) quadratic form; drop both
  # singleton axes.
  mahalanobis = jnp.squeeze(
      jnp.matmul(diff[..., None, :], solved[..., :, None]), axis=(-2, -1))
  return -0.5 * (log_two_pi_term + log_det + mahalanobis)
def entropy(self, log_probs):
  """Differential entropy of a one-dimensional Gaussian with std `self._std`.

  H = log(std) + 0.5 * log(2 * pi * e). Assumes `self._std` is the standard
  deviation itself (not log-std). For an isotropic Gaussian over D dimensions,
  the total entropy is D times this value.

  Args:
    log_probs: Unused; accepted for interface compatibility.

  Returns:
    The entropy value (does not depend on `log_probs`).
  """
  del log_probs  # would be helpful if self._std was learnable
  # log(std), not exp(std): the entropy grows logarithmically with the spread.
  return jnp.log(self._std) + .5 * jnp.log(2.0 * jnp.pi * jnp.e)
def entropy(self):
  """Differential entropy of a one-dimensional Gaussian with std `self._std`.

  H = log(std) + 0.5 * log(2 * pi * e). Assumes `self._std` is the standard
  deviation itself (not log-std). For an isotropic Gaussian over D dimensions,
  the total entropy is D times this value.
  """
  # log(std), not exp(std): the entropy grows logarithmically with the spread.
  return jnp.log(self._std) + .5 * jnp.log(2.0 * jnp.pi * jnp.e)
def entropy(self):
  """Differential entropy of a one-dimensional Gaussian with std `self._std`.

  H = log(std) + 0.5 * log(2 * pi * e). Assumes `self._std` is the standard
  deviation itself (not log-std). For an isotropic Gaussian over D dimensions,
  the total entropy is D times this value.
  """
  # log(std), not exp(std): the entropy grows logarithmically with the spread.
  return np.log(self._std) + .5 * np.log(2.0 * np.pi * np.e)