Example #1
0
    def __init__(self,
                 logits: Optional[Numeric] = None,
                 probs: Optional[Numeric] = None,
                 dtype: jnp.dtype = jnp.int_):
        """Builds a Bernoulli distribution from logits or from probabilities.

        Args:
          logits: Logit of the probability of drawing a `1` (a `0` is drawn
            otherwise), so that `probs = sigmoid(logits)`. Mutually exclusive
            with `probs`.
          probs: Probability of drawing a `1` (a `0` is drawn otherwise).
            Mutually exclusive with `logits`.
          dtype: The type of event samples.

        Raises:
          ValueError: If `dtype` is not a boolean, integer or floating-point
            type.
        """
        super().__init__()

        # Exactly one of the two parameterizations has to be supplied.
        chex.assert_exactly_one_is_none(probs, logits)

        # Samples may only be boolean, integer or floating-point valued.
        supported_kinds = (bool, jnp.integer, jnp.floating)
        if not any(jnp.issubdtype(dtype, kind) for kind in supported_kinds):
            raise ValueError(
                f'The dtype of `{self.name}` must be boolean, integer or '
                f'floating-point, instead got `{dtype}`.')

        # Store whichever parameter was given as a float array; the other one
        # stays `None`.
        if probs is None:
            self._probs = None
        else:
            self._probs = conversion.as_float_array(probs)

        if logits is None:
            self._logits = None
        else:
            self._logits = conversion.as_float_array(logits)

        self._dtype = dtype
Example #2
0
    def __init__(self, low: Numeric = 0., high: Numeric = 1.):
        """Builds a Uniform distribution over the interval from `low` to `high`.

        Both bounds are converted with `conversion.as_float_array`, and the
        batch shape is the broadcast of their shapes.

        Args:
          low: Lower bound.
          high: Upper bound.
        """
        super().__init__()
        self._low = lower = conversion.as_float_array(low)
        self._high = upper = conversion.as_float_array(high)
        # The two bounds only need to be broadcast-compatible with each other.
        self._batch_shape = jax.lax.broadcast_shapes(lower.shape, upper.shape)
Example #3
0
  def __init__(self, loc: Numeric, scale: Numeric):
    """Builds a Normal distribution with the given mean and standard deviation.

    Both parameters are converted with `conversion.as_float_array`, and the
    batch shape is the broadcast of their shapes.

    Args:
      loc: Mean of the distribution.
      scale: Standard deviation of the distribution.
    """
    super().__init__()
    self._loc = mean = conversion.as_float_array(loc)
    self._scale = stddev = conversion.as_float_array(scale)
    # `loc` and `scale` only need to be broadcast-compatible with each other.
    self._batch_shape = jax.lax.broadcast_shapes(mean.shape, stddev.shape)
Example #4
0
    def __init__(self, loc: Numeric, scale: Numeric) -> None:
        """Builds a Logistic distribution with the given location and spread.

        Both parameters are converted with `conversion.as_float_array`, and
        the batch shape is the broadcast of their shapes.

        Args:
          loc: Mean of the distribution.
          scale: Spread of the distribution.
        """
        super().__init__()
        self._loc = location = conversion.as_float_array(loc)
        self._scale = spread = conversion.as_float_array(scale)
        # `loc` and `scale` only need to be broadcast-compatible.
        param_shapes = (location.shape, spread.shape)
        self._batch_shape = jax.lax.broadcast_shapes(*param_shapes)
Example #5
0
    def __init__(self,
                 loc: Optional[Array] = None,
                 scale_diag: Optional[Array] = None):
        """Builds a MultivariateNormalDiag distribution.

        Args:
          loc: Mean vector of the distribution, or a batch of such vectors.
            Defaults to zeros when omitted. `loc` and `scale_diag` cannot both
            be omitted.
          scale_diag: Vector of standard deviations, or a batch of such
            vectors. Defaults to ones when omitted. `loc` and `scale_diag`
            cannot both be omitted.

        Raises:
          ValueError: If a provided argument has no dimensions, or if the last
            dimensions of `loc` and `scale_diag` differ.
        """
        super().__init__()
        chex.assert_not_both_none(loc, scale_diag)

        has_loc = loc is not None
        has_scale = scale_diag is not None

        # Any provided argument must be at least a vector (non-empty shape).
        if has_scale and not scale_diag.shape:
            raise ValueError(
                'If provided, argument `scale_diag` must have at least '
                '1 dimension.')
        if has_loc and not loc.shape:
            raise ValueError('If provided, argument `loc` must have at least '
                             '1 dimension.')

        # When both are given, their event (last) dimensions must agree.
        if has_loc and has_scale:
            if loc.shape[-1] != scale_diag.shape[-1]:
                raise ValueError(
                    f'The last dimension of arguments `loc` and '
                    f'`scale_diag` must coincide, but {loc.shape[-1]} != '
                    f'{scale_diag.shape[-1]}.')

        if has_loc and has_scale:
            mean = conversion.as_float_array(loc)
            stddev = conversion.as_float_array(scale_diag)
        elif not has_scale:
            # Only `loc` given: use unit standard deviations of matching size
            # and dtype.
            mean = conversion.as_float_array(loc)
            stddev = jnp.ones(mean.shape[-1], mean.dtype)
        else:
            # Only `scale_diag` given: use a zero mean of matching size and
            # dtype.
            stddev = conversion.as_float_array(scale_diag)
            mean = jnp.zeros(stddev.shape[-1], stddev.dtype)

        self._loc = mean
        self._scale_diag = stddev

        # The batch shape is everything except the trailing event dimension.
        self._batch_shape = jax.lax.broadcast_shapes(
            mean.shape[:-1], stddev.shape[:-1])
Example #6
0
    def __init__(self,
                 loc: Numeric,
                 log_scale: Numeric,
                 max_scale: Optional[float] = None):
        """Builds a LogStddevNormal distribution.

        Args:
          loc: Mean of the distribution.
          log_scale: Log of the distribution's scale. This is often the
            pre-activated output of a neural network.
          max_scale: Optional value at which the scale saturates, which can
            help numerical stability. It is a soft cap rather than a hard
            maximum: the stored log-scale is
            `log(max_scale) - softplus(log(max_scale) - log_scale)`, and the
            scale passed to the parent class is its exponential.
        """
        if max_scale is None:
            # No cap requested: use the log-scale as given.
            self._log_scale = conversion.as_float_array(log_scale)
        else:
            # Smoothly saturate the log-scale at log(max_scale).
            log_cap = math.log(max_scale)
            headroom = log_cap - conversion.as_float_array(log_scale)
            self._log_scale = log_cap - jax.nn.softplus(headroom)

        super().__init__(loc, jnp.exp(self._log_scale))
Example #7
0
 def test_on_invalid_array(self, dtype):
     # A scalar-shaped array of the parameterized dtype must be rejected
     # with a ValueError.
     rejected_input = jnp.zeros([], dtype)
     self.assertRaises(
         ValueError, conversion.as_float_array, rejected_input)
Example #8
0
 def test_on_int_array(self, dtype):
     # Converting an array of the parameterized dtype must yield a jnp array
     # whose dtype is float64 when x64 is enabled and float32 otherwise.
     converted = conversion.as_float_array(jnp.zeros([], dtype))
     self.assertIsInstance(converted, jnp.ndarray)
     if jax.config.x64_enabled:
         expected_dtype = jnp.float64
     else:
         expected_dtype = jnp.float32
     self.assertEqual(converted.dtype, expected_dtype)
Example #9
0
 def test_on_float_array(self, dtype):
     # An array of the parameterized dtype must come back as the very same
     # object (identity, not just equality).
     float_input = jnp.zeros([], dtype)
     returned = conversion.as_float_array(float_input)
     self.assertIs(returned, float_input)
Example #10
0
 def test_on_invalid_scalar(self, x):
     # The parameterized scalar must be rejected with a ValueError.
     self.assertRaises(ValueError, conversion.as_float_array, x)
Example #11
0
 def test_on_valid_scalar(self, x):
     # Converting the parameterized scalar must yield a jnp array whose dtype
     # is float64 when x64 is enabled and float32 otherwise.
     converted = conversion.as_float_array(x)
     self.assertIsInstance(converted, jnp.ndarray)
     if jax.config.x64_enabled:
         expected_dtype = jnp.float64
     else:
         expected_dtype = jnp.float32
     self.assertEqual(converted.dtype, expected_dtype)