Ejemplo n.º 1
0
def bernoulli(key: np.ndarray,
              p: np.ndarray = onp.float32(0.5),
              shape: Optional[Sequence[int]] = None) -> np.ndarray:
    """Draw Bernoulli random values with the requested shape and mean.

    Args:
      key: a PRNGKey used as the random key.
      p: optional, a float or array of floats giving the mean of the random
        variables. Must be broadcast-compatible with ``shape``. Default 0.5.
      shape: optional, a tuple of nonnegative integers giving the result
        shape. Must be broadcast-compatible with ``p.shape``. When left as
        None (the default) the result shape equals ``p.shape``.

    Returns:
      A random array with boolean dtype whose shape is ``shape`` when
      ``shape`` is not None, and ``p.shape`` otherwise.

    Raises:
      TypeError: if the canonicalized dtype of ``p`` is not a floating dtype.
    """
    p_dtype = dtypes.canonicalize_dtype(lax.dtype(p))
    # None is forwarded untouched; only explicit shapes are canonicalized.
    result_shape = (None if shape is None
                    else abstract_arrays.canonicalize_shape(shape))
    if np.issubdtype(p_dtype, onp.floating):
        return _bernoulli(key, lax.convert_element_type(p, p_dtype),
                          result_shape)
    raise TypeError(
        "bernoulli probability `p` must have a floating dtype, got {}."
        .format(p_dtype))
Ejemplo n.º 2
0
def beta(key: np.ndarray,
         a: Union[float, np.ndarray],
         b: Union[float, np.ndarray],
         shape: Optional[Sequence[int]] = None,
         dtype: onp.dtype = onp.float64) -> np.ndarray:
    """Draw Beta-distributed random values with the given shape and dtype.

    Args:
      key: a PRNGKey used as the random key.
      a: a float or array of floats, broadcast-compatible with ``shape``,
        giving the first parameter "alpha".
      b: a float or array of floats, broadcast-compatible with ``shape``,
        giving the second parameter "beta".
      shape: optional, a tuple of nonnegative integers giving the result
        shape. Must be broadcast-compatible with ``a`` and ``b``. When None
        (the default) the result shape comes from broadcasting ``a`` and
        ``b``.
      dtype: optional, a float dtype for the returned values (default float64
        if jax_enable_x64 is true, otherwise float32).

    Returns:
      A random array with the specified dtype, shaped by ``shape`` when it is
      not None and otherwise by broadcasting ``a`` and ``b``.
    """
    float_dtype = dtypes.canonicalize_dtype(dtype)
    # Leave None alone so the sampler can derive the shape from a and b.
    result_shape = shape
    if result_shape is not None:
        result_shape = abstract_arrays.canonicalize_shape(result_shape)
    return _beta(key, a, b, result_shape, float_dtype)
Ejemplo n.º 3
0
def multivariate_normal(key: np.ndarray,
                        mean: np.ndarray,
                        cov: np.ndarray,
                        shape: Optional[Sequence[int]] = None,
                        dtype: onp.dtype = onp.float64) -> np.ndarray:
    """Draw multivariate normal random values for a given mean and covariance.

    Args:
      key: a PRNGKey used as the random key.
      mean: a mean vector of shape ``(..., n)``.
      cov: a positive definite covariance matrix of shape ``(..., n, n)``.
        Its batch shape ``...`` must be broadcast-compatible with that of
        ``mean``.
      shape: optional, a tuple of nonnegative integers giving the result
        batch shape, i.e. the result shape without its last axis. Must be
        broadcast-compatible with ``mean.shape[:-1]`` and ``cov.shape[:-2]``.
        When None (the default) the batch shape is obtained by broadcasting
        the batch shapes of ``mean`` and ``cov``.
      dtype: optional, a float dtype for the returned values (default float64
        if jax_enable_x64 is true, otherwise float32).

    Returns:
      A random array with the specified dtype and shape
      ``shape + mean.shape[-1:]`` when ``shape`` is not None, otherwise
      ``broadcast_shapes(mean.shape[:-1], cov.shape[:-2]) + mean.shape[-1:]``.
    """
    canonical_dtype = dtypes.canonicalize_dtype(dtype)
    # Only an explicit batch shape is canonicalized; None passes through.
    batch_shape = (None if shape is None
                   else abstract_arrays.canonicalize_shape(shape))
    return _multivariate_normal(key, mean, cov, batch_shape, canonical_dtype)
Ejemplo n.º 4
0
def truncated_normal(key: np.ndarray,
                     lower: Union[float, np.ndarray],
                     upper: Union[float, np.ndarray],
                     shape: Optional[Sequence[int]] = None,
                     dtype: onp.dtype = onp.float64) -> np.ndarray:
    """Draw truncated standard normal random values with given shape and dtype.

    Args:
      key: a PRNGKey used as the random key.
      lower: a float or array of floats giving the lower truncation bound.
        Must be broadcast-compatible with ``upper``.
      upper: a float or array of floats giving the upper truncation bound.
        Must be broadcast-compatible with ``lower``.
      shape: optional, a tuple of nonnegative integers giving the result
        shape. Must be broadcast-compatible with ``lower`` and ``upper``.
        When None (the default) the result shape comes from broadcasting
        ``lower`` and ``upper``.
      dtype: optional, a float dtype for the returned values (default float64
        if jax_enable_x64 is true, otherwise float32).

    Returns:
      A random array with the specified dtype, shaped by ``shape`` when it is
      not None and otherwise by broadcasting ``lower`` and ``upper``.
    """
    out_dtype = dtypes.canonicalize_dtype(dtype)
    out_shape = shape
    if out_shape is not None:
        # Explicit shapes are canonicalized before being handed on.
        out_shape = abstract_arrays.canonicalize_shape(out_shape)
    return _truncated_normal(key, lower, upper, out_shape, out_dtype)
Ejemplo n.º 5
0
def _check_shape(name, shape, *param_shapes):
  """Validate that parameter shapes broadcast exactly to ``shape``.

  Args:
    name: label inserted at the start of the error message.
    shape: the requested result shape; canonicalized before comparison.
    *param_shapes: shapes of the distribution parameters. When none are
      given, only the canonicalization of ``shape`` is performed.

  Raises:
    ValueError: if broadcasting ``shape`` together with ``param_shapes``
      yields a shape different from the canonicalized ``shape``.
  """
  shape = abstract_arrays.canonicalize_shape(shape)
  if not param_shapes:
    return

  broadcasted = lax.broadcast_shapes(shape, *param_shapes)
  if shape != broadcasted:
    template = ("{} parameter shapes must be broadcast-compatible with shape "
                "argument, and the result of broadcasting the shapes must equal "
                "the shape argument, but got result {} for shape argument {}.")
    raise ValueError(template.format(name, broadcasted, shape))
Ejemplo n.º 6
0
def exponential(key, shape=(), dtype=onp.float64):
    """Draw Exponential random values with the given shape and float dtype.

    Args:
      key: a PRNGKey used as the random key.
      shape: optional, a tuple of nonnegative integers giving the result
        shape. Default ().
      dtype: optional, a float dtype for the returned values (default float64
        if jax_enable_x64 is true, otherwise float32).

    Returns:
      A random array with the specified shape and dtype.
    """
    # Canonicalize both arguments before handing them to the sampler.
    canonical_dtype = dtypes.canonicalize_dtype(dtype)
    canonical_shape = abstract_arrays.canonicalize_shape(shape)
    return _exponential(key, canonical_shape, canonical_dtype)
Ejemplo n.º 7
0
def uniform(key, shape=(), dtype=onp.float64, minval=0., maxval=1.):
    """Draw uniform random values in [minval, maxval) with given shape/dtype.

    Args:
      key: a PRNGKey used as the random key.
      shape: optional, a tuple of nonnegative integers giving the result
        shape. Default ().
      dtype: optional, a float dtype for the returned values (default float64
        if jax_enable_x64 is true, otherwise float32).
      minval: optional, the minimum (inclusive) value of the range
        (default 0).
      maxval: optional, the maximum (exclusive) value of the range
        (default 1).

    Returns:
      A random array with the specified shape and dtype.
    """
    value_dtype = dtypes.canonicalize_dtype(dtype)
    result_shape = abstract_arrays.canonicalize_shape(shape)
    # The range bounds are forwarded unchanged.
    return _uniform(key, result_shape, value_dtype, minval, maxval)
Ejemplo n.º 8
0
def randint(key, shape, minval, maxval, dtype=onp.int64):
    """Draw uniform random integers in [minval, maxval) with given shape/dtype.

    Args:
      key: a PRNGKey used as the random key.
      shape: a tuple of nonnegative integers giving the result shape.
      minval: int or array of ints broadcast-compatible with ``shape``; the
        minimum (inclusive) value of the range.
      maxval: int or array of ints broadcast-compatible with ``shape``; the
        maximum (exclusive) value of the range.
      dtype: optional, an int dtype for the returned values (default int64 if
        jax_enable_x64 is true, otherwise int32).

    Returns:
      A random array with the specified shape and dtype.
    """
    int_dtype = dtypes.canonicalize_dtype(dtype)
    result_shape = abstract_arrays.canonicalize_shape(shape)
    # The range bounds are forwarded unchanged.
    return _randint(key, result_shape, minval, maxval, int_dtype)
Ejemplo n.º 9
0
def t(key, df, shape=(), dtype=onp.float64):
    """Sample Student's t random values with given shape and float dtype.

    Args:
      key: a PRNGKey used as the random key.
      df: a float or array of floats broadcast-compatible with ``shape``
        representing the parameter of the distribution.
      shape: optional, a tuple of nonnegative integers specifying the result
        shape. Must be broadcast-compatible with ``df``. Default ().
        ``None`` is also accepted and is forwarded to the underlying sampler
        without canonicalization; it is intended to request a result shape
        equal to ``df.shape``.
        NOTE(review): confirm that ``_t`` resolves ``None`` to ``df.shape``.
      dtype: optional, a float dtype for the returned values (default float64
        if jax_enable_x64 is true, otherwise float32).

    Returns:
      A random array with the specified dtype and with shape given by
      ``shape`` if ``shape`` is not None.
    """
    dtype = dtypes.canonicalize_dtype(dtype)
    # Only canonicalize an explicit shape: None is a documented value and
    # must not be fed to canonicalize_shape.
    if shape is not None:
        shape = abstract_arrays.canonicalize_shape(shape)
    return _t(key, df, shape, dtype)
Ejemplo n.º 10
0
def poisson(key, lam, shape=(), dtype=np.int64):
  """Sample Poisson random values with given shape and integer dtype.

  Args:
    key: a PRNGKey used as the random key.
    lam: rate parameter (mean of the distribution), must be >= 0. May be a
      Python scalar or an array; it is broadcast to ``shape`` and converted
      to float32 before sampling.
    shape: optional, a tuple of nonnegative integers representing the result
      shape. Default ().
    dtype: optional, an integer dtype for the returned values (default int64
      if jax_enable_x64 is true, otherwise int32).

  Returns:
    A random array with the specified shape and dtype.
  """
  dtype = dtypes.canonicalize_dtype(dtype)
  shape = abstract_arrays.canonicalize_shape(shape)
  # Promote to an array up front: a Python scalar whose shape already matches
  # (e.g. ``poisson(key, 2.0)`` with the default shape) skips the broadcast
  # below and has no ``astype`` method.
  lam = jnp.asarray(lam)
  if np.shape(lam) != shape:
    lam = jnp.broadcast_to(lam, shape)
  lam = lam.astype(np.float32)
  return _poisson(key, lam, shape, dtype)
Ejemplo n.º 11
0
def dirichlet(key, alpha, shape=None, dtype=onp.float64):
    """Draw Dirichlet random values with the given shape and float dtype.

    Args:
      key: a PRNGKey used as the random key.
      alpha: an array of shape ``(..., n)`` used as the concentration
        parameter of the random variables.
      shape: optional, a tuple of nonnegative integers giving the result
        batch shape, i.e. the result shape without its last element of value
        ``n``. Must be broadcast-compatible with ``alpha.shape[:-1]``. When
        None (the default) the result shape equals ``alpha.shape``.
      dtype: optional, a float dtype for the returned values (default float64
        if jax_enable_x64 is true, otherwise float32).

    Returns:
      A random array with the specified dtype and shape
      ``shape + (alpha.shape[-1],)`` when ``shape`` is not None, otherwise
      ``alpha.shape``.
    """
    float_dtype = dtypes.canonicalize_dtype(dtype)
    # None is passed through so the sampler can fall back to alpha's shape.
    batch_shape = (None if shape is None
                   else abstract_arrays.canonicalize_shape(shape))
    return _dirichlet(key, alpha, batch_shape, float_dtype)