def __init__(self,
             logits: Optional[Numeric] = None,
             probs: Optional[Numeric] = None,
             dtype: jnp.dtype = jnp.int_):
  """Initializes a Bernoulli distribution.

  Args:
    logits: Logit transform of the probability of a `1` event (`0` otherwise),
      i.e. `probs = sigmoid(logits)`. Only one of `logits` or `probs` can be
      specified.
    probs: Probability of a `1` event (`0` otherwise). Only one of `logits` or
      `probs` can be specified.
    dtype: The type of event samples.
  """
  super().__init__()
  # Validate arguments.
  chex.assert_exactly_one_is_none(probs, logits)
  if not (jnp.issubdtype(dtype, bool) or jnp.issubdtype(dtype, jnp.integer)
          or jnp.issubdtype(dtype, jnp.floating)):
    raise ValueError(
        f'The dtype of `{self.name}` must be boolean, integer or '
        f'floating-point, instead got `{dtype}`.')
  # Parameters of the distribution.
  self._probs = None if probs is None else conversion.as_float_array(probs)
  self._logits = None if logits is None else conversion.as_float_array(logits)
  self._dtype = dtype
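# Usage sketch, not part of the library: it assumes this `__init__` belongs to
# a distrax-style `Bernoulli` class whose base class provides the usual
# `sample`/`log_prob` interface. Exactly one of `logits` or `probs` may be
# given; passing both (or neither) trips `chex.assert_exactly_one_is_none`.
import jax
import jax.numpy as jnp

bern = Bernoulli(logits=jnp.array([0.0, 2.0]))     # probs = sigmoid(logits)
# bern = Bernoulli(probs=jnp.array([0.5, 0.88]))   # equivalent parametrization
samples = bern.sample(seed=jax.random.PRNGKey(0))  # integer samples by default (dtype=jnp.int_)
log_p = bern.log_prob(jnp.array([1, 0]))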
def __init__(self, low: Numeric = 0., high: Numeric = 1.):
  """Initializes a Uniform distribution.

  Args:
    low: Lower bound.
    high: Upper bound.
  """
  super().__init__()
  self._low = conversion.as_float_array(low)
  self._high = conversion.as_float_array(high)
  self._batch_shape = jax.lax.broadcast_shapes(
      self._low.shape, self._high.shape)
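# Usage sketch (assumes a distrax-style `Uniform` with `sample`/`log_prob`
# provided by the base class). `low` and `high` broadcast against each other,
# which is exactly what determines `_batch_shape` above.
import jax
import jax.numpy as jnp

uni = Uniform(low=jnp.zeros((3,)), high=jnp.array(2.0))  # batch_shape == (3,)
x = uni.sample(seed=jax.random.PRNGKey(0))
log_p = uni.log_prob(x)  # -log(high - low) inside the support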
def __init__(self, loc: Numeric, scale: Numeric):
  """Initializes a Normal distribution.

  Args:
    loc: Mean of the distribution.
    scale: Standard deviation of the distribution.
  """
  super().__init__()
  self._loc = conversion.as_float_array(loc)
  self._scale = conversion.as_float_array(scale)
  self._batch_shape = jax.lax.broadcast_shapes(
      self._loc.shape, self._scale.shape)
def __init__(self, loc: Numeric, scale: Numeric) -> None:
  """Initializes a Logistic distribution.

  Args:
    loc: Mean of the distribution.
    scale: Spread of the distribution.
  """
  super().__init__()
  self._loc = conversion.as_float_array(loc)
  self._scale = conversion.as_float_array(scale)
  self._batch_shape = jax.lax.broadcast_shapes(
      self._loc.shape, self._scale.shape)
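# Usage sketch covering the two constructors above (assumed distrax-style
# `Normal` and `Logistic` classes): both take broadcastable `loc`/`scale`
# arguments, and the batch shape is their broadcast shape.
import jax
import jax.numpy as jnp

norm = Normal(loc=jnp.zeros((2, 3)), scale=jnp.ones((3,)))  # batch_shape == (2, 3)
logi = Logistic(loc=0.0, scale=2.0)                         # scalar batch
key_n, key_l = jax.random.split(jax.random.PRNGKey(0))
x = norm.sample(seed=key_n)
y = logi.sample(seed=key_l)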
def __init__(self,
             loc: Optional[Array] = None,
             scale_diag: Optional[Array] = None):
  """Initializes a MultivariateNormalDiag distribution.

  Args:
    loc: Mean vector of the distribution. Can also be a batch of vectors. If
      not specified, it defaults to zeros. At least one of `loc` and
      `scale_diag` must be specified.
    scale_diag: Vector of standard deviations. Can also be a batch of vectors.
      If not specified, it defaults to ones. At least one of `loc` and
      `scale_diag` must be specified.
  """
  super().__init__()
  chex.assert_not_both_none(loc, scale_diag)
  if scale_diag is not None and not scale_diag.shape:
    raise ValueError(
        'If provided, argument `scale_diag` must have at least 1 dimension.')
  if loc is not None and not loc.shape:
    raise ValueError(
        'If provided, argument `loc` must have at least 1 dimension.')
  if loc is not None and scale_diag is not None and (
      loc.shape[-1] != scale_diag.shape[-1]):
    raise ValueError(
        f'The last dimension of arguments `loc` and `scale_diag` must '
        f'coincide, but {loc.shape[-1]} != {scale_diag.shape[-1]}.')
  if scale_diag is None:
    self._loc = conversion.as_float_array(loc)
    self._scale_diag = jnp.ones(self._loc.shape[-1], self._loc.dtype)
  elif loc is None:
    self._scale_diag = conversion.as_float_array(scale_diag)
    self._loc = jnp.zeros(self._scale_diag.shape[-1], self._scale_diag.dtype)
  else:
    self._loc = conversion.as_float_array(loc)
    self._scale_diag = conversion.as_float_array(scale_diag)
  self._batch_shape = jax.lax.broadcast_shapes(
      self._loc.shape[:-1], self._scale_diag.shape[:-1])
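# Usage sketch (assumed distrax-style `MultivariateNormalDiag`). At least one
# of `loc`/`scale_diag` must be given; the missing one defaults to zeros/ones
# of the matching event size, and batch shapes broadcast over all but the
# last axis.
import jax
import jax.numpy as jnp

mvn = MultivariateNormalDiag(loc=jnp.zeros((4, 3)))                 # scale_diag defaults to ones(3)
mvn2 = MultivariateNormalDiag(scale_diag=jnp.array([1., 2., 3.]))   # loc defaults to zeros(3)
x = mvn.sample(seed=jax.random.PRNGKey(0))                          # shape (4, 3)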
def __init__(self,
             loc: Numeric,
             log_scale: Numeric,
             max_scale: Optional[float] = None):
  """Initializes a LogStddevNormal distribution.

  Args:
    loc: Mean of the distribution.
    log_scale: Log of the distribution's scale. This is often the
      pre-activated output of a neural network.
    max_scale: Maximum value of the scale that this distribution will saturate
      at. This parameter can be useful for numerical stability. It is not a
      hard maximum; rather, the log of the scale is computed as
      `log(max_scale) - softplus(log(max_scale) - log_scale)`.
  """
  if max_scale is not None:
    max_log_scale = math.log(max_scale)
    self._log_scale = max_log_scale - jax.nn.softplus(
        max_log_scale - conversion.as_float_array(log_scale))
  else:
    self._log_scale = conversion.as_float_array(log_scale)
  scale = jnp.exp(self._log_scale)
  super().__init__(loc, scale)
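# Numerical sketch of the `max_scale` soft saturation used above: the log scale
# is squashed through `log(max_scale) - softplus(log(max_scale) - log_scale)`,
# so small `log_scale` values pass through almost unchanged while large values
# approach (but never reach) `log(max_scale)`.
import math
import jax
import jax.numpy as jnp

max_scale = 4.0
log_scale = jnp.array([-2.0, 0.0, 5.0])
max_log_scale = math.log(max_scale)
soft_log_scale = max_log_scale - jax.nn.softplus(max_log_scale - log_scale)
print(jnp.exp(soft_log_scale))  # roughly [0.13, 0.8, 3.9]; always below max_scale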
def test_on_invalid_array(self, dtype):
  x = jnp.zeros([], dtype)
  with self.assertRaises(ValueError):
    conversion.as_float_array(x)
def test_on_int_array(self, dtype):
  x = jnp.zeros([], dtype)
  y = conversion.as_float_array(x)
  self.assertIsInstance(y, jnp.ndarray)
  self.assertEqual(
      y.dtype, jnp.float64 if jax.config.x64_enabled else jnp.float32)
def test_on_float_array(self, dtype):
  x = jnp.zeros([], dtype)
  y = conversion.as_float_array(x)
  self.assertIs(y, x)
def test_on_invalid_scalar(self, x):
  with self.assertRaises(ValueError):
    conversion.as_float_array(x)
def test_on_valid_scalar(self, x):
  y = conversion.as_float_array(x)
  self.assertIsInstance(y, jnp.ndarray)
  self.assertEqual(
      y.dtype, jnp.float64 if jax.config.x64_enabled else jnp.float32)
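# Summary sketch of the behaviour these tests exercise for
# `conversion.as_float_array` (inferred from the assertions above, not an
# independent spec): the parametrized "invalid" arrays and scalars raise
# `ValueError`; integer arrays and valid Python scalars are converted to the
# default float dtype (float64 when `jax.config.x64_enabled`, else float32);
# arrays that are already floating-point are returned unchanged (`assertIs`).
import jax
import jax.numpy as jnp

y = conversion.as_float_array(jnp.zeros([], jnp.int32))
assert y.dtype == (jnp.float64 if jax.config.x64_enabled else jnp.float32)
x = jnp.zeros([], jnp.float32)
assert conversion.as_float_array(x) is x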