def bernoulli(key: np.ndarray,
              p: np.ndarray = onp.float32(0.5),
              shape: Optional[Sequence[int]] = None) -> np.ndarray:
  """Draw boolean Bernoulli samples whose mean is ``p``.

  Args:
    key: PRNGKey supplying the randomness.
    p: optional float or array of floats giving the mean of each random
      variable; must be broadcast-compatible with ``shape``. Defaults to 0.5.
    shape: optional tuple of nonnegative integers for the result shape; must
      be broadcast-compatible with ``p.shape``. When None (the default) the
      result takes the shape ``p.shape``.

  Returns:
    A boolean random array shaped like ``shape`` when it is given, otherwise
    like ``p.shape``.

  Raises:
    TypeError: if ``p`` does not have a floating-point dtype.
  """
  p_dtype = dtypes.canonicalize_dtype(lax.dtype(p))
  if shape is not None:
    shape = abstract_arrays.canonicalize_shape(shape)
  # Only floating-point probabilities are meaningful here.
  if np.issubdtype(p_dtype, onp.floating):
    return _bernoulli(key, lax.convert_element_type(p, p_dtype), shape)
  raise TypeError(
      "bernoulli probability `p` must have a floating dtype, got {}."
      .format(p_dtype))
def beta(key: np.ndarray,
         a: Union[float, np.ndarray],
         b: Union[float, np.ndarray],
         shape: Optional[Sequence[int]] = None,
         dtype: onp.dtype = onp.float64) -> np.ndarray:
  """Draw Beta-distributed samples with the requested shape and float dtype.

  Args:
    key: PRNGKey supplying the randomness.
    a: float or array of floats, broadcast-compatible with ``shape``, giving
      the first parameter "alpha".
    b: float or array of floats, broadcast-compatible with ``shape``, giving
      the second parameter "beta".
    shape: optional tuple of nonnegative integers for the result shape; must
      be broadcast-compatible with ``a`` and ``b``. When None (the default)
      the result shape comes from broadcasting ``a`` against ``b``.
    dtype: optional float dtype of the output (float64 when jax_enable_x64
      is enabled, float32 otherwise).

  Returns:
    A random array of dtype ``dtype`` whose shape is ``shape`` when given,
    otherwise the broadcast shape of ``a`` and ``b``.
  """
  canonical_dtype = dtypes.canonicalize_dtype(dtype)
  canonical_shape = (None if shape is None
                     else abstract_arrays.canonicalize_shape(shape))
  return _beta(key, a, b, canonical_shape, canonical_dtype)
def multivariate_normal(key: np.ndarray,
                        mean: np.ndarray,
                        cov: np.ndarray,
                        shape: Optional[Sequence[int]] = None,
                        dtype: onp.dtype = onp.float64) -> np.ndarray:
  """Draw multivariate normal samples for a given mean and covariance.

  Args:
    key: PRNGKey supplying the randomness.
    mean: mean vector(s) with shape ``(..., n)``.
    cov: positive definite covariance matrix (or matrices) with shape
      ``(..., n, n)``; its batch dimensions ``...`` must broadcast against
      those of ``mean``.
    shape: optional tuple of nonnegative integers giving the batch shape of
      the result, i.e. every dimension except the trailing one. Must be
      broadcast-compatible with ``mean.shape[:-1]`` and ``cov.shape[:-2]``.
      When None (the default) the batch shape is the broadcast of the batch
      shapes of ``mean`` and ``cov``.
    dtype: optional float dtype of the output (float64 when jax_enable_x64
      is enabled, float32 otherwise).

  Returns:
    A random array of dtype ``dtype`` with shape ``shape + mean.shape[-1:]``
    when ``shape`` is given, otherwise
    ``broadcast_shapes(mean.shape[:-1], cov.shape[:-2]) + mean.shape[-1:]``.
  """
  canonical_dtype = dtypes.canonicalize_dtype(dtype)
  canonical_shape = (None if shape is None
                     else abstract_arrays.canonicalize_shape(shape))
  return _multivariate_normal(key, mean, cov, canonical_shape, canonical_dtype)
def truncated_normal(key: np.ndarray,
                     lower: Union[float, np.ndarray],
                     upper: Union[float, np.ndarray],
                     shape: Optional[Sequence[int]] = None,
                     dtype: onp.dtype = onp.float64) -> np.ndarray:
  """Draw standard normal samples truncated to the interval [lower, upper].

  Args:
    key: PRNGKey supplying the randomness.
    lower: float or array of floats giving the lower truncation bound; must
      be broadcast-compatible with ``upper``.
    upper: float or array of floats giving the upper truncation bound; must
      be broadcast-compatible with ``lower``.
    shape: optional tuple of nonnegative integers for the result shape; must
      be broadcast-compatible with ``lower`` and ``upper``. When None (the
      default) the result shape is the broadcast of ``lower`` and ``upper``.
    dtype: optional float dtype of the output (float64 when jax_enable_x64
      is enabled, float32 otherwise).

  Returns:
    A random array of dtype ``dtype`` whose shape is ``shape`` when given,
    otherwise the broadcast shape of ``lower`` and ``upper``.
  """
  canonical_dtype = dtypes.canonicalize_dtype(dtype)
  canonical_shape = (None if shape is None
                     else abstract_arrays.canonicalize_shape(shape))
  return _truncated_normal(key, lower, upper, canonical_shape, canonical_dtype)
def _check_shape(name, shape, *param_shapes):
  """Validate that parameter shapes broadcast exactly to ``shape``.

  Args:
    name: distribution name, used only in the error message.
    shape: the requested result shape; it is canonicalized first.
    *param_shapes: shapes of the distribution parameters. When none are
      given, no check is performed.

  Raises:
    ValueError: if broadcasting ``shape`` with all ``param_shapes`` yields
      a shape different from ``shape`` itself.
  """
  requested = abstract_arrays.canonicalize_shape(shape)
  if not param_shapes:
    return
  broadcasted = lax.broadcast_shapes(requested, *param_shapes)
  if broadcasted == requested:
    return
  template = ("{} parameter shapes must be broadcast-compatible with shape "
              "argument, and the result of broadcasting the shapes must equal "
              "the shape argument, but got result {} for shape argument {}.")
  raise ValueError(template.format(name, broadcasted, requested))
def exponential(key, shape=(), dtype=onp.float64):
  """Draw Exponential-distributed samples with the given shape and dtype.

  Args:
    key: PRNGKey supplying the randomness.
    shape: optional tuple of nonnegative integers for the result shape.
      Defaults to ().
    dtype: optional float dtype of the output (float64 when jax_enable_x64
      is enabled, float32 otherwise).

  Returns:
    A random array with shape ``shape`` and dtype ``dtype``.
  """
  canonical_dtype = dtypes.canonicalize_dtype(dtype)
  canonical_shape = abstract_arrays.canonicalize_shape(shape)
  return _exponential(key, canonical_shape, canonical_dtype)
def uniform(key, shape=(), dtype=onp.float64, minval=0., maxval=1.):
  """Draw uniform samples from the half-open interval [minval, maxval).

  Args:
    key: PRNGKey supplying the randomness.
    shape: optional tuple of nonnegative integers for the result shape.
      Defaults to ().
    dtype: optional float dtype of the output (float64 when jax_enable_x64
      is enabled, float32 otherwise).
    minval: optional inclusive lower bound of the range (default 0).
    maxval: optional exclusive upper bound of the range (default 1).

  Returns:
    A random array with shape ``shape`` and dtype ``dtype``.
  """
  canonical_dtype = dtypes.canonicalize_dtype(dtype)
  canonical_shape = abstract_arrays.canonicalize_shape(shape)
  return _uniform(key, canonical_shape, canonical_dtype, minval, maxval)
def randint(key, shape, minval, maxval, dtype=onp.int64):
  """Draw uniform random integers from the half-open range [minval, maxval).

  Args:
    key: PRNGKey supplying the randomness.
    shape: tuple of nonnegative integers for the result shape.
    minval: int or array of ints, broadcast-compatible with ``shape``, giving
      the inclusive lower bound.
    maxval: int or array of ints, broadcast-compatible with ``shape``, giving
      the exclusive upper bound.
    dtype: optional integer dtype of the output (int64 when jax_enable_x64
      is enabled, int32 otherwise).

  Returns:
    A random array with shape ``shape`` and dtype ``dtype``.
  """
  canonical_dtype = dtypes.canonicalize_dtype(dtype)
  canonical_shape = abstract_arrays.canonicalize_shape(shape)
  return _randint(key, canonical_shape, minval, maxval, canonical_dtype)
def t(key, df, shape=(), dtype=onp.float64):
  """Sample Student's t random values with given shape and float dtype.

  Args:
    key: a PRNGKey used as the random key.
    df: a float or array of floats broadcast-compatible with ``shape``
      representing the parameter of the distribution.
    shape: optional, a tuple of nonnegative integers specifying the result
      shape. Must be broadcast-compatible with ``df``. Default ().
    dtype: optional, a float dtype for the returned values (default float64 if
      jax_enable_x64 is true, otherwise float32).

  Returns:
    A random array with the specified dtype and with shape given by ``shape``.
  """
  dtype = dtypes.canonicalize_dtype(dtype)
  # ``shape`` is canonicalized unconditionally (unlike the samplers whose
  # default is None), so the result shape is always the given ``shape``.
  shape = abstract_arrays.canonicalize_shape(shape)
  return _t(key, df, shape, dtype)
def poisson(key, lam, shape=(), dtype=np.int64):
  """Sample Poisson random values with given shape and integer dtype.

  Args:
    key: a PRNGKey used as the random key.
    lam: rate parameter (mean of the distribution), must be >= 0. May be a
      Python scalar or an array; it is broadcast to ``shape`` when its shape
      differs from ``shape``.
    shape: optional, a tuple of nonnegative integers representing the result
      shape. Default ().
    dtype: optional, an integer dtype for the returned values (default int64
      if jax_enable_x64 is true, otherwise int32).

  Returns:
    A random array with the specified shape and dtype.
  """
  dtype = dtypes.canonicalize_dtype(dtype)
  shape = abstract_arrays.canonicalize_shape(shape)
  if np.shape(lam) != shape:
    lam = jnp.broadcast_to(lam, shape)
  # When no broadcast was needed (e.g. ``poisson(key, 3.0)`` with the default
  # shape ``()``), ``lam`` can still be a Python scalar, which has no
  # ``.astype`` method. ``lax.convert_element_type`` accepts both scalars and
  # arrays.
  lam = lax.convert_element_type(lam, np.float32)
  return _poisson(key, lam, shape, dtype)
def dirichlet(key, alpha, shape=None, dtype=onp.float64):
  """Draw Dirichlet-distributed samples with the given shape and float dtype.

  Args:
    key: PRNGKey supplying the randomness.
    alpha: array of shape ``(..., n)`` holding the concentration parameters.
    shape: optional tuple of nonnegative integers giving the batch shape of
      the result, i.e. every dimension except the trailing axis of size
      ``n``. Must be broadcast-compatible with ``alpha.shape[:-1]``. When
      None (the default) the result shape equals ``alpha.shape``.
    dtype: optional float dtype of the output (float64 when jax_enable_x64
      is enabled, float32 otherwise).

  Returns:
    A random array of dtype ``dtype`` with shape
    ``shape + (alpha.shape[-1],)`` when ``shape`` is given, otherwise
    ``alpha.shape``.
  """
  canonical_dtype = dtypes.canonicalize_dtype(dtype)
  canonical_shape = (None if shape is None
                     else abstract_arrays.canonicalize_shape(shape))
  return _dirichlet(key, alpha, canonical_shape, canonical_dtype)