Example 1
def test_ravel_pytree(pytree):
    flat, unravel_fn = ravel_pytree(pytree)
    unravel = unravel_fn(flat)
    tree_flatten(tree_multimap(lambda x, y: assert_allclose(x, y), unravel, pytree))
    assert all(tree_flatten(tree_multimap(lambda x, y:
                                          canonicalize_dtype(lax.dtype(x)) == canonicalize_dtype(lax.dtype(y)),
                                          unravel, pytree))[0])
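A minimal round-trip sketch of what the test above checks, assuming the `ravel_pytree` exposed by `jax.flatten_util` (numpyro ships an equivalent helper):

import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

tree = {'w': jnp.ones((2, 3)), 'b': jnp.zeros(3)}
flat, unravel_fn = ravel_pytree(tree)   # every leaf raveled into one 1-D array
print(flat.shape)                       # (9,)
restored = unravel_fn(flat)             # same structure, shapes and dtypes as `tree`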
Example 2
def _uniform(key, shape, dtype, minval, maxval):
    if not onp.issubdtype(dtype, onp.floating):
        raise TypeError("uniform only accepts floating point dtypes.")

    dtype = xla_bridge.canonicalize_dtype(dtype)
    minval = lax.convert_element_type(minval, dtype)
    maxval = lax.convert_element_type(maxval, dtype)
    finfo = onp.finfo(dtype)
    nbits, nmant = finfo.bits, finfo.nmant

    if nbits not in (32, 64):
        raise TypeError("uniform only accepts 32- or 64-bit dtypes.")

    bits = _random_bits(key, nbits, shape)

    # The strategy here is to randomize only the mantissa bits with an exponent of
    # 1 (after applying the bias), then shift and scale to the desired range. The
    # bit-level transformation we use relies on Numpy and XLA having bit-for-bit
    # equivalent float representations, which might not be true on all platforms.
    float_bits = lax.bitwise_or(
        lax.shift_right_logical(bits, onp.array(nbits - nmant,
                                                lax.dtype(bits))),
        onp.array(1., dtype).view(onp.uint32 if nbits == 32 else onp.uint64))
    floats = lax.bitcast_convert_type(float_bits, dtype) - onp.array(1., dtype)
    return lax.max(minval,
                   lax.reshape(floats * (maxval - minval) + minval, shape))
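The bit-level strategy in the comment can be checked with plain NumPy for float32 (`nbits=32`, `nmant=23`); a minimal sketch, separate from the library code:

import numpy as onp

rng = onp.random.default_rng(0)
bits = rng.integers(0, onp.iinfo(onp.uint32).max, size=5, dtype=onp.uint32, endpoint=True)
one_bits = onp.array(1., onp.float32).view(onp.uint32)    # bit pattern of 1.0f: biased exponent, zero mantissa
float_bits = (bits >> onp.uint32(32 - 23)) | one_bits     # keep 23 random mantissa bits, exponent of 1.0
floats = float_bits.view(onp.float32) - onp.float32(1.)   # values in [1, 2) shifted to [0, 1)
print(floats)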
Example 3
def prepare_single_layer_model(input_size, output_size, width, key):
    init_random_params, predict = stax.serial(Dense(width), Relu,
                                              Dense(output_size), LogSoftmax)

    key, split = random.split(key)
    _, params = init_random_params(split, (-1, input_size))

    cast = lambda x: x.astype(canonicalize_dtype(onp.float64))
    params = tree_util.tree_map(cast, params)
    return predict, params, key
Example 4
def __init__(self, value=0., log_density=0., event_ndim=0, validate_args=None):
    if event_ndim > np.ndim(value):
        raise ValueError('Expected event_dim <= v.dim(), actual {} vs {}'
                         .format(event_ndim, np.ndim(value)))
    batch_dim = np.ndim(value) - event_ndim
    batch_shape = np.shape(value)[:batch_dim]
    event_shape = np.shape(value)[batch_dim:]
    self.value = lax.convert_element_type(value, xla_bridge.canonicalize_dtype(np.float64))
    # NB: following Pyro implementation, log_density should be broadcasted to batch_shape
    self.log_density = promote_shapes(log_density, shape=batch_shape)[0]
    super(Delta, self).__init__(batch_shape, event_shape, validate_args=validate_args)
Example 5
def get_batch(input_size, output_size, batch_size, key):
    key, split = random.split(key)
    # jax.random will always generate float32 even if jax_enable_x64==True.
    xs = random.normal(split,
                       shape=(batch_size, input_size),
                       dtype=canonicalize_dtype(onp.float64))
    key, split = random.split(key)
    ys = random.randint(split,
                        minval=0,
                        maxval=output_size,
                        shape=(batch_size, ))
    ys = to_onehot(ys, output_size)
    return (xs, ys), key
Example 6
def randint(key, shape, minval, maxval, dtype=onp.int32):
    """Sample uniform random values in [minval, maxval) with given shape/dtype.

  Args:
    key: a PRNGKey used as the random key.
    shape: a tuple of nonnegative integers representing the shape.
    minval: a minimum (inclusive) value for the range.
    maxval: a maximum (exclusive) value for the range.
    dtype: optional, an int dtype for the returned values (default int32).

  Returns:
    A random array with the specified shape and dtype.
  """
    if not onp.issubdtype(dtype, onp.integer):
        raise TypeError("randint only accepts integer dtypes.")

    dtype = xla_bridge.canonicalize_dtype(dtype)
    minval = lax.convert_element_type(minval, dtype)
    maxval = lax.convert_element_type(maxval, dtype)
    nbits = onp.iinfo(dtype).bits

    if nbits not in (32, 64):
        raise TypeError("randint only accepts 32- or 64-bit dtypes.")

    # if we don't have minval < maxval, just always return minval
    # https://github.com/google/jax/issues/222
    maxval = lax.max(lax.add(minval, onp.array(1, dtype)), maxval)

    # This algorithm is biased whenever (maxval - minval) is not a power of 2.
    # We generate double the number of random bits required by the dtype so as to
    # reduce that bias.
    k1, k2 = split(key)
    rbits = lambda key: _random_bits(key, nbits, shape)
    higher_bits, lower_bits = rbits(k1), rbits(k2)

    unsigned_dtype = onp.uint32 if nbits == 32 else onp.uint64
    span = lax.convert_element_type(maxval - minval, unsigned_dtype)

    # To compute a remainder operation on an integer that might have twice as many
    # bits as we can represent in the native unsigned dtype, we compute a
    # multiplier equal to 2**nbits % span (using that nbits is 32 or 64).
    multiplier = lax.rem(onp.array(2**16, unsigned_dtype), span)
    multiplier = lax.rem(lax.mul(multiplier, multiplier), span)
    if nbits == 64:
        multiplier = lax.rem(lax.mul(multiplier, multiplier), span)

    random_offset = lax.add(lax.mul(lax.rem(higher_bits, span), multiplier),
                            lax.rem(lower_bits, span))
    random_offset = lax.rem(random_offset, span)
    return lax.add(minval, lax.convert_element_type(random_offset, dtype))
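The multiplier construction relies on the identity `(a * 2**nbits + b) % span == ((a % span) * (2**nbits % span) + b % span) % span`; a quick self-check with plain Python integers standing in for the two 32-bit draws:

nbits, span = 32, 1000003
higher, lower = 0xDEADBEEF, 0x12345678           # two independent 32-bit samples
multiplier = (2**16 % span) ** 2 % span          # == 2**32 % span, built from 16-bit pieces
offset = ((higher % span) * multiplier + lower % span) % span
assert offset == (higher * 2**nbits + lower) % span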
Example 7
def normal(key, shape=(), dtype=onp.float64):
  """Sample standard normal random values with given shape and float dtype.

  Args:
    key: a PRNGKey used as the random key.
    shape: a tuple of nonnegative integers representing the shape.
    dtype: optional, a float dtype for the returned values (default float64 if
      jax_enable_x64 is true, otherwise float32).

  Returns:
    A random array with the specified shape and dtype.
  """
  dtype = xla_bridge.canonicalize_dtype(dtype)
  return _normal(key, shape, dtype)
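The canonicalization step is what makes the "float64 if jax_enable_x64 is true, otherwise float32" default work; a small sketch using the public helper that current jax exposes as `jax.dtypes.canonicalize_dtype` (the snippets here use the older `xla_bridge.canonicalize_dtype` spelling):

import numpy as onp
from jax import dtypes

print(dtypes.canonicalize_dtype(onp.float64))   # float32 unless the jax_enable_x64 flag is set
print(dtypes.canonicalize_dtype(onp.float32))   # float32 either way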
Example 8
def _ravel_list(*leaves):
    leaves_metadata = tree_map(
        lambda l: pytree_metadata(np.ravel(l), np.shape(l), np.size(l),
                                  canonicalize_dtype(lax.dtype(l))), leaves)
    leaves_idx = np.cumsum(
        np.array((0, ) + tuple(d.size for d in leaves_metadata)))

    def unravel_list(arr):
        return [
            np.reshape(lax.dynamic_slice_in_dim(arr, leaves_idx[i], m.size),
                       m.shape).astype(m.dtype)
            for i, m in enumerate(leaves_metadata)
        ]

    return np.concatenate([m.flat for m in leaves_metadata]), unravel_list
Example 9
def eig_abstract_eval(operand):
  if isinstance(operand, ShapedArray):
    if operand.ndim < 2 or operand.shape[-2] != operand.shape[-1]:
      raise ValueError("Argument to nonsymmetric eigendecomposition must have "
                       "shape [..., n, n], got shape {}".format(operand.shape))

    batch_dims = operand.shape[:-2]
    n = operand.shape[-1]
    dtype = onp.complex64 if onp.finfo(operand.dtype).bits == 32 else onp.complex128
    dtype = xb.canonicalize_dtype(dtype)
    vl = vr = ShapedArray(batch_dims + (n, n), dtype)
    w = ShapedArray(batch_dims + (n,), dtype)
  else:
    raise NotImplementedError
  return w, vl, vr
Example 10
def uniform(key, shape=(), dtype=onp.float64, minval=0., maxval=1.):
  """Sample uniform random values in [minval, maxval) with given shape/dtype.

  Args:
    key: a PRNGKey used as the random key.
    shape: a tuple of nonnegative integers representing the shape.
    dtype: optional, a float dtype for the returned values (default float64 if
      jax_enable_x64 is true, otherwise float32).
    minval: optional, a minimum (inclusive) value for the range (default 0).
    maxval: optional, a maximum (exclusive) value for the range (default 1).

  Returns:
    A random array with the specified shape and dtype.
  """
  dtype = xla_bridge.canonicalize_dtype(dtype)
  return _uniform(key, shape, dtype, minval, maxval)
Example 11
def t(key, df, shape=(), dtype=onp.float64):
  """Sample Student's t random values with given shape and float dtype.

  Args:
    key: a PRNGKey used as the random key.
    df: an array-like broadcastable to `shape` and used as the shape parameter
      of the random variables.
    shape: optional, a tuple of nonnegative integers representing the shape
      (default scalar).
    dtype: optional, a float dtype for the returned values (default float64 if
      jax_enable_x64 is true, otherwise float32).

  Returns:
    A random array with the specified shape and dtype.
  """
  dtype = xla_bridge.canonicalize_dtype(dtype)
  return _t(key, df, shape, dtype)
Example 12
def dirichlet(key, alpha, shape=(), dtype=onp.float64):
  """Sample Cauchy random values with given shape and float dtype.

  Args:
    key: a PRNGKey used as the random key.
    alpha: an array-like with `alpha.shape[:-1]` broadcastable to `shape` and
      used as the concentration parameter of the random variables.
    shape: optional, a tuple of nonnegative integers representing the batch
      shape (defaults to `alpha.shape[:-1]`).
    dtype: optional, a float dtype for the returned values (default float64 if
      jax_enable_x64 is true, otherwise float32).

  Returns:
    A random array with the specified shape and dtype.
  """
  dtype = xla_bridge.canonicalize_dtype(dtype)
  return _dirichlet(key, alpha, shape, dtype)
Example 13
def randint(key, shape, minval, maxval, dtype=onp.int64):
  """Sample uniform random values in [minval, maxval) with given shape/dtype.

  Args:
    key: a PRNGKey used as the random key.
    shape: a tuple of nonnegative integers representing the shape.
    minval: int or array of ints broadcast-compatible with ``shape``, a minimum
      (inclusive) value for the range.
    maxval: int or array of ints broadcast-compatible with ``shape``, a maximum
      (exclusive) value for the range.
    dtype: optional, an int dtype for the returned values (default int64 if
      jax_enable_x64 is true, otherwise int32).

  Returns:
    A random array with the specified shape and dtype.
  """
  dtype = xla_bridge.canonicalize_dtype(dtype)
  return _randint(key, shape, minval, maxval, dtype)
Example 14
def bernoulli(key, p=onp.float32(0.5), shape=()):
  """Sample Bernoulli random values with given shape and mean.

  Args:
    key: a PRNGKey used as the random key.
    p: optional, an array-like of floating dtype broadcastable to `shape` for
      the mean of the random variables (default 0.5).
    shape: optional, a tuple of nonnegative integers representing the shape
      (default scalar).

  Returns:
    A random array with the specified shape and boolean dtype.
  """
  dtype = xla_bridge.canonicalize_dtype(lax.dtype(p))
  if not onp.issubdtype(dtype, onp.floating):
    msg = "bernoulli probability `p` must have a floating dtype, got {}."
    raise TypeError(msg.format(dtype))
  p = lax.convert_element_type(p, dtype)
  return _bernoulli(key, p, shape)
Example 15
def t(key, df, shape=(), dtype=onp.float64):
  """Sample Student's t random values with given shape and float dtype.

  Args:
    key: a PRNGKey used as the random key.
    df: a float or array of floats broadcast-compatible with ``shape``
      representing the parameter of the distribution.
    shape: optional, a tuple of nonnegative integers specifying the result
      shape. Must be broadcast-compatible with ``df``. The default (None)
      produces a result shape equal to ``df.shape``.
    dtype: optional, a float dtype for the returned values (default float64 if
      jax_enable_x64 is true, otherwise float32).

  Returns:
    A random array with the specified dtype and with shape given by ``shape`` if
    ``shape`` is not None, or else by ``df.shape``.
  """
  dtype = xla_bridge.canonicalize_dtype(dtype)
  return _t(key, df, shape, dtype)
Example 16
def bernoulli(key, p=onp.float32(0.5), shape=None):
  """Sample Bernoulli random values with given shape and mean.

  Args:
    key: a PRNGKey used as the random key.
    p: optional, a float or array of floats for the mean of the random
      variables. Must be broadcast-compatible with ``shape``. Default 0.5.
    shape: optional, a tuple of nonnegative integers representing the result
      shape. Must be broadcast-compatible with ``p.shape``. The default (None)
      produces a result shape equal to ``p.shape``.

  Returns:
    A random array with boolean dtype and shape given by ``shape`` if ``shape``
    is not None, or else ``p.shape``.
  """
  dtype = xla_bridge.canonicalize_dtype(lax.dtype(p))
  if not onp.issubdtype(dtype, onp.floating):
    msg = "bernoulli probability `p` must have a floating dtype, got {}."
    raise TypeError(msg.format(dtype))
  p = lax.convert_element_type(p, dtype)
  return _bernoulli(key, p, shape)
Example 17
def beta(key, a, b, shape=None, dtype=onp.float64):
  """Sample Bernoulli random values with given shape and mean.

  Args:
    key: a PRNGKey used as the random key.
    a: a float or array of floats broadcast-compatible with ``shape``
      representing the first parameter "alpha".
    b: a float or array of floats broadcast-compatible with ``shape``
      representing the second parameter "beta".
    shape: optional, a tuple of nonnegative integers specifying the result
      shape. Must be broadcast-compatible with ``a`` and ``b``. The default
      (None) produces a result shape by broadcasting ``a`` and ``b``.
    dtype: optional, a float dtype for the returned values (default float64 if
      jax_enable_x64 is true, otherwise float32).

  Returns:
    A random array with the specified dtype and shape given by ``shape`` if
    ``shape`` is not None, or else by broadcasting ``a`` and ``b``.
  """
  dtype = xla_bridge.canonicalize_dtype(dtype)
  return _beta(key, a, b, shape, dtype)
Example 18
def _randint(key, shape, minval, maxval, dtype=onp.int32):
    _check_shape("randint", shape)
    if not onp.issubdtype(dtype, onp.integer):
        raise TypeError("randint only accepts integer dtypes.")

    dtype = xla_bridge.canonicalize_dtype(dtype)
    minval = lax.convert_element_type(minval, dtype)
    maxval = lax.convert_element_type(maxval, dtype)
    nbits = onp.iinfo(dtype).bits

    if nbits not in (32, 64):
        raise TypeError("randint only accepts 32- or 64-bit dtypes.")

    # if we don't have minval < maxval, just always return minval
    # https://github.com/google/jax/issues/222
    maxval = lax.max(lax.add(minval, onp.array(1, dtype)), maxval)

    # This algorithm is biased whenever (maxval - minval) is not a power of 2.
    # We generate double the number of random bits required by the dtype so as to
    # reduce that bias.
    k1, k2 = split(key)
    rbits = lambda key: _random_bits(key, nbits, shape)
    higher_bits, lower_bits = rbits(k1), rbits(k2)

    unsigned_dtype = onp.uint32 if nbits == 32 else onp.uint64
    span = lax.convert_element_type(maxval - minval, unsigned_dtype)

    # To compute a remainder operation on an integer that might have twice as many
    # bits as we can represent in the native unsigned dtype, we compute a
    # multiplier equal to 2**nbits % span (using that nbits is 32 or 64).
    multiplier = lax.rem(onp.array(2**16, unsigned_dtype), span)
    multiplier = lax.rem(lax.mul(multiplier, multiplier), span)
    if nbits == 64:
        multiplier = lax.rem(lax.mul(multiplier, multiplier), span)

    random_offset = lax.add(lax.mul(lax.rem(higher_bits, span), multiplier),
                            lax.rem(lower_bits, span))
    random_offset = lax.rem(random_offset, span)
    return lax.add(minval, lax.convert_element_type(random_offset, dtype))
Example 19
def dirichlet(key, alpha, shape=None, dtype=onp.float64):
  """Sample Cauchy random values with given shape and float dtype.

  Args:
    key: a PRNGKey used as the random key.
    alpha: an array of shape ``(..., n)`` used as the concentration
      parameter of the random variables.
    shape: optional, a tuple of nonnegative integers specifying the result
      batch shape; that is, the prefix of the result shape excluding the last
      element of value ``n``. Must be broadcast-compatible with
      ``alpha.shape[:-1]``. The default (None) produces a result shape equal to
      ``alpha.shape``.
    dtype: optional, a float dtype for the returned values (default float64 if
      jax_enable_x64 is true, otherwise float32).

  Returns:
    A random array with the specified dtype and shape given by
    ``shape + (alpha.shape[-1],)`` if ``shape`` is not None, or else
    ``alpha.shape``.
  """
  dtype = xla_bridge.canonicalize_dtype(dtype)
  return _dirichlet(key, alpha, shape, dtype)
Example 20
def truncated_normal(key, lower, upper, shape=None, dtype=onp.float64):
  """Sample truncated standard normal random values with given shape and dtype.

  Args:
    key: a PRNGKey used as the random key.
    lower: a float or array of floats representing the lower bound for
      truncation. Must be broadcast-compatible with ``upper``.
    upper: a float or array of floats representing the upper bound for
      truncation. Must be broadcast-compatible with ``lower``.
    shape: optional, a tuple of nonnegative integers specifying the result
      shape. Must be broadcast-compatible with ``lower`` and ``upper``. The
      default (None) produces a result shape by broadcasting ``lower`` and
      ``upper``.
    dtype: optional, a float dtype for the returned values (default float64 if
      jax_enable_x64 is true, otherwise float32).

  Returns:
    A random array with the specified dtype and shape given by ``shape`` if
    ``shape`` is not None, or else by broadcasting ``lower`` and ``upper``.
  """
  dtype = xla_bridge.canonicalize_dtype(dtype)
  return _truncated_normal(key, lower, upper, shape, dtype)
Example 21
def uniform(key, shape, dtype=onp.float32, minval=0., maxval=1.):
    """Sample uniform random values in [minval, maxval) with given shape/dtype.

  Args:
    key: a PRNGKey used as the random key.
    shape: a tuple of nonnegative integers representing the shape.
    dtype: optional, a float dtype for the returned values (default float32).
    minval: optional, a minimum (inclusive) value for the range (default 0).
    maxval: optional, a maximum (exclusive) value for the range (default 1).

  Returns:
    A random array with the specified shape and dtype.
  """
    if not onp.issubdtype(dtype, onp.floating):
        raise TypeError("uniform only accepts floating point dtypes.")

    dtype = xla_bridge.canonicalize_dtype(dtype)
    minval = lax.convert_element_type(minval, dtype)
    maxval = lax.convert_element_type(maxval, dtype)
    finfo = onp.finfo(dtype)
    nbits, nmant = finfo.bits, finfo.nmant

    if nbits not in (32, 64):
        raise TypeError("uniform only accepts 32- or 64-bit dtypes.")

    bits = _random_bits(key, nbits, shape)

    # The strategy here is to randomize only the mantissa bits with an exponent of
    # 1 (after applying the bias), then shift and scale to the desired range. The
    # bit-level transformation we use relies on Numpy and XLA having bit-for-bit
    # equivalent float representations, which might not be true on all platforms.
    float_bits = lax.bitwise_or(
        lax.shift_right_logical(bits, onp.array(nbits - nmant,
                                                lax.dtype(bits))),
        onp.array(1., dtype).view(onp.uint32 if nbits == 32 else onp.uint64))
    floats = lax.bitcast_convert_type(float_bits, dtype) - onp.array(1., dtype)
    return lax.max(minval,
                   lax.reshape(floats * (maxval - minval) + minval, shape))
Example 22
def multivariate_normal(key, mean, cov, shape=None, dtype=onp.float64):
  """Sample multivariate normal random values with given mean and covariance.

  Args:
    key: a PRNGKey used as the random key.
    mean: a mean vector of shape ``(..., n)``.
    cov: a positive definite covariance matrix of shape ``(..., n, n)``. The
      batch shape ``...`` must be broadcast-compatible with that of ``mean``.
    shape: optional, a tuple of nonnegative integers specifying the result
      batch shape; that is, the prefix of the result shape excluding the last
      axis. Must be broadcast-compatible with ``mean.shape[:-1]`` and
      ``cov.shape[:-2]``. The default (None) produces a result batch shape by
      broadcasting together the batch shapes of ``mean`` and ``cov``.
    dtype: optional, a float dtype for the returned values (default float64 if
      jax_enable_x64 is true, otherwise float32).

  Returns:
    A random array with the specified dtype and shape given by
    ``shape + mean.shape[-1:]`` if ``shape`` is not None, or else
    ``broadcast_shapes(mean.shape[:-1], cov.shape[:-2]) + mean.shape[-1:]``.
  """
  dtype = xla_bridge.canonicalize_dtype(dtype)
  return _multivariate_normal(key, mean, cov, shape, dtype)
Example 23
def float_types():
    return set(
        onp.dtype(xla_bridge.canonicalize_dtype(dtype))
        for dtype in [onp.float32, onp.float64])
Example 24
def get_dtype(x):
    return canonicalize_dtype(lax.dtype(x))
Example 25
def poisson(key, rate, shape, dtype=np.int64):
    dtype = canonicalize_dtype(dtype)
    return _poisson(key, rate, shape, dtype)
Example 26
def complex_types():
  return sorted(list({onp.dtype(xla_bridge.canonicalize_dtype(dtype))
                     for dtype in [onp.complex64, onp.complex128]}))
Example 27
def standard_gamma(key, alpha, shape=(), dtype=np.float64):
    dtype = xla_bridge.canonicalize_dtype(dtype)
    return _standard_gamma(key, alpha, shape, dtype)
Example 28
def _inputs_to_kernel(x1, x2, use_pooling, compute_ntk):
    """Transforms (batches of) inputs to a `Kernel`.

  This is a private method. Docstring and example are for internal reference.

   The kernel contains the empirical covariances between different inputs and
     their entries (pixels) necessary to compute the covariance of the Gaussian
     Process corresponding to an infinite Bayesian or gradient-flow-trained
     neural network.

   The smallest necessary number of covariance entries is tracked. For example,
     all networks are assumed to have i.i.d. weights along the channel / feature
     / logits dimensions, hence covariance between different entries along these
     dimensions is known to be 0 and is not tracked.

  Args:
    x1: a 2D `np.ndarray` of shape `[batch_size_1, n_features]` (dense
      network) or 4D of shape `[batch_size_1, height, width, channels]`
      (conv-nets).
    x2: an optional `np.ndarray` with the same shape as `x1` apart
      from possibly different leading batch size. `None` means
      `x2 == x1`.
    use_pooling: a boolean, indicating whether pooling will be used somewhere in
      the model. If so, more covariance entries need to be tracked. Is set
      automatically based on the network topology. Specifically, is set to
      `False` if a `serial` or `parallel` network contains a `Flatten` layer
      and no pooling layers (`AvgPool` or `GlobalAvgPool`). Has no effect for
      non-convolutional models.
    compute_ntk: a boolean, `True` to compute both NTK and NNGP kernels,
        `False` to only compute NNGP.

    Example:
      ```python
          >>> x = np.ones((10, 32, 16, 3))
          >>> _inputs_to_kernel(x, None, use_pooling=True,
          >>>                   compute_ntk=True).ntk.shape
          (10, 10, 32, 32, 16, 16)
          >>> _inputs_to_kernel(x, None, use_pooling=False,
          >>>                   compute_ntk=True).ntk.shape
          (10, 10, 32, 16)
          >>> x1 = np.ones((10, 128))
          >>> x2 = np.ones((20, 128))
          >>> _inputs_to_kernel(x1, x2, use_pooling=True,
          >>>                   compute_ntk=False).nngp.shape
          (10, 20)
          >>> _inputs_to_kernel(x1, x2, use_pooling=False,
          >>>                   compute_ntk=False).nngp.shape
          (10, 20)
          >>> _inputs_to_kernel(x1, x2, use_pooling=False,
          >>>                   compute_ntk=False).ntk
          None
      ```

  Returns:
    a `Kernel` object.
  """
    x1 = x1.astype(xla_bridge.canonicalize_dtype(np.float64))
    var1 = _get_variance(x1)

    if x2 is None:
        x2 = x1
        var2 = None
    else:
        if x1.shape[1:] != x2.shape[1:]:
            raise ValueError(
                '`x1` and `x2` are expected to be batches of'
                ' inputs with the same shape (apart from the batch size),'
                ' got %s and %s.' % (str(x1.shape), str(x2.shape)))

        x2 = x2.astype(xla_bridge.canonicalize_dtype(np.float64))
        var2 = _get_variance(x2)

    if use_pooling and x1.ndim == 4:
        x2 = np.expand_dims(x2, -1)
        nngp = np.dot(x1, x2) / x1.shape[-1]
        nngp = np.transpose(np.squeeze(nngp, -1), (0, 3, 1, 4, 2, 5))

    elif x1.ndim == 4 or x1.ndim == 2:
        nngp = _batch_uncentered_covariance(x1, x2)

    else:
        raise ValueError('Inputs must be 2D or 4D `np.ndarray`s of shape '
                         '`[batch_size, n_features]` or '
                         '`[batch_size, height, width, channels]`, '
                         'got %s.' % str(x1.shape))

    ntk = 0. if compute_ntk else None
    is_gaussian = False
    is_height_width = True
    return Kernel(var1, nngp, var2, ntk, is_gaussian, is_height_width)
Example 29
    def init_kernel(init_params,
                    num_warmup,
                    step_size=1.0,
                    adapt_step_size=True,
                    adapt_mass_matrix=True,
                    dense_mass=False,
                    target_accept_prob=0.8,
                    trajectory_length=2*math.pi,
                    max_tree_depth=10,
                    run_warmup=True,
                    progbar=True,
                    rng=PRNGKey(0)):
        """
        Initializes the HMC sampler.

        :param init_params: Initial parameters to begin sampling. The type must
            be consistent with the input type to `potential_fn`.
        :param int num_warmup: Number of warmup steps; samples generated
            during warmup are discarded.
        :param float step_size: Determines the size of a single step taken by the
            verlet integrator while computing the trajectory using Hamiltonian
            dynamics. If not specified, it will be set to 1.
        :param bool adapt_step_size: A flag to decide if we want to adapt step_size
            during warm-up phase using Dual Averaging scheme.
        :param bool adapt_mass_matrix: A flag to decide if we want to adapt mass
            matrix during warm-up phase using Welford scheme.
        :param bool dense_mass: A flag to decide if mass matrix is dense or
            diagonal (default when ``dense_mass=False``)
        :param float target_accept_prob: Target acceptance probability for step size
            adaptation using Dual Averaging. Increasing this value will lead to a smaller
            step size, hence the sampling will be slower but more robust. Default to 0.8.
        :param float trajectory_length: Length of a MCMC trajectory for HMC. Default
            value is :math:`2\\pi`.
        :param int max_tree_depth: Max depth of the binary tree created during the doubling
            scheme of NUTS sampler. Defaults to 10.
        :param bool run_warmup: Flag to decide whether warmup is run. If ``True``,
            `init_kernel` returns an initial :data:`~numpyro.mcmc.HMCState` that can be used to
            generate samples using MCMC. Else, returns the arguments and callable
            that does the initial adaptation.
        :param bool progbar: Whether to enable progress bar updates. Defaults to
            ``True``.
        :param jax.random.PRNGKey rng: random key to be used as the source of
            randomness.
        """
        step_size = lax.convert_element_type(step_size, xla_bridge.canonicalize_dtype(np.float64))
        nonlocal momentum_generator, wa_update, trajectory_len, max_treedepth, wa_steps
        wa_steps = num_warmup
        trajectory_len = trajectory_length
        max_treedepth = max_tree_depth
        z = init_params
        z_flat, unravel_fn = ravel_pytree(z)
        momentum_generator = partial(_sample_momentum, unravel_fn)

        find_reasonable_ss = partial(find_reasonable_step_size,
                                     potential_fn, kinetic_fn,
                                     momentum_generator)

        wa_init, wa_update = warmup_adapter(num_warmup,
                                            adapt_step_size=adapt_step_size,
                                            adapt_mass_matrix=adapt_mass_matrix,
                                            dense_mass=dense_mass,
                                            target_accept_prob=target_accept_prob,
                                            find_reasonable_step_size=find_reasonable_ss)

        rng_hmc, rng_wa = random.split(rng)
        wa_state = wa_init(z, rng_wa, step_size, mass_matrix_size=np.size(z_flat))
        r = momentum_generator(wa_state.mass_matrix_sqrt, rng)
        vv_state = vv_init(z, r)
        hmc_state = HMCState(0, vv_state.z, vv_state.z_grad, vv_state.potential_energy, 0, 0., 0.,
                             False, wa_state, rng_hmc)

        # TODO: Remove; this should be the responsibility of the MCMC class.
        if run_warmup and num_warmup > 0:
            # JIT if progress bar updates not required
            if not progbar:
                hmc_state = fori_loop(0, num_warmup, lambda *args: sample_kernel(args[1]), hmc_state)
            else:
                with tqdm.trange(num_warmup, desc='warmup') as t:
                    for i in t:
                        hmc_state = jit(sample_kernel)(hmc_state)
                        t.set_postfix_str(get_diagnostics_str(hmc_state), refresh=False)
        return hmc_state
Example 30
def _get_num_steps(step_size, trajectory_length):
    num_steps = np.clip(trajectory_length / step_size, a_min=1)
    # NB: casting to np.int64 does not take effect (returns np.int32 instead)
    # if jax_enable_x64 is False
    return num_steps.astype(xla_bridge.canonicalize_dtype(np.int64))
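A minimal check of the NB comment, again via the public `jax.dtypes.canonicalize_dtype` (standing in for the `xla_bridge` import used above):

import jax.numpy as jnp
from jax import dtypes

num_steps = jnp.clip(14.6, 1)                    # stand-in for trajectory_length / step_size
int_dtype = dtypes.canonicalize_dtype(jnp.int64)
print(num_steps.astype(int_dtype).dtype)         # int32 unless jax_enable_x64 is set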