Example #1
def type_assert(inputs: Union[ArrayLike, List[ArrayLike]],
                expected_types: Union[Type[Scalar], List[Type[Scalar]]]):
    """Checks that the type of all inputs matches specified expected_types.

  Args:
    inputs: list of inputs.
    expected_types: list of expected types associated with each input; if all
      inputs have same type, a single type may be passed as `expected_types`.

  Raises:
    ValueError: if the length of inputs and expected_types do not match.
  """
    if not isinstance(inputs, list):
        inputs = [inputs]
    if not isinstance(expected_types, list):
        expected_types = [expected_types] * len(inputs)
    if len(inputs) != len(expected_types):
        raise ValueError("Length of inputs and expected_types must match.")
    for x, expected in zip(inputs, expected_types):
        if jnp.issubdtype(expected, jnp.floating):
            parent = jnp.floating
        elif jnp.issubdtype(expected, jnp.integer):
            parent = jnp.integer
        else:
            raise ValueError(
                "Error in type compatibility check, unsupported dtype"
                " {}".format(expected))

        if not jnp.issubdtype(jnp.result_type(x), parent):
            raise ValueError("Error in type compatibility check, found {} but "
                             "expected {}.".format(jnp.result_type(x),
                                                   expected))
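A minimal usage sketch for the helper above (assuming only `jax.numpy` is available; `x` and `y` are illustrative arrays):

import jax.numpy as jnp

x = jnp.array([1.0, 2.0])   # float32
y = jnp.array([1, 2])       # int32
type_assert([x, y], [float, int])   # passes silently
# type_assert(x, int) would raise ValueError: found float32 but expected int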
Example #2
def assert_type(inputs: Union[Scalar, Union[Array, Sequence[Array]]],
                expected_types: Union[Type[Scalar], Sequence[Type[Scalar]]]):
    """Checks that the type of all `inputs` matches specified `expected_types`.

  Valid usages include:

  ```
    assert_type(7, int)
    assert_type(7.1, float)
    assert_type(False, bool)
    assert_type([7, 8], int)
    assert_type([7, 7.1], [int, float])
    assert_type(np.array(7), int)
    assert_type(np.array(7.1), float)
    assert_type(jnp.array(7), int)
    assert_type([jnp.array([7, 8]), np.array(7.1)], [int, float])
  ```

  Args:
    inputs: array or sequence of arrays or scalars.
    expected_types: sequence of expected types associated with each input; if
      all inputs have same type, a single type may be passed as
      `expected_types`.

  Raises:
    AssertionError: if the length of `inputs` and `expected_types` don't match;
                    if `expected_types` contains unsupported pytype;
                    if the types of input do not match the expected types.
  """
    if not isinstance(inputs, (list, tuple)):
        inputs = [inputs]
    if not isinstance(expected_types, (list, tuple)):
        expected_types = [expected_types] * len(inputs)

    errors = []
    if len(inputs) != len(expected_types):
        raise AssertionError(
            f"Length of `inputs` and `expected_types` must match, "
            f"got {len(inputs)} != {len(expected_types)}.")
    for idx, (x, expected) in enumerate(zip(inputs, expected_types)):
        if jnp.issubdtype(expected, jnp.floating):
            parent = jnp.floating
        elif jnp.issubdtype(expected, jnp.integer):
            parent = jnp.integer
        elif jnp.issubdtype(expected, jnp.bool_):
            parent = jnp.bool_
        else:
            raise AssertionError(
                f"Error in type compatibility check, unsupported dtype '{expected}'."
            )

        if not jnp.issubdtype(jnp.result_type(x), parent):
            errors.append((idx, jnp.result_type(x), expected))

    if errors:
        msg = "; ".join("input {} has type {} but expected {}".format(*err)
                        for err in errors)

        raise AssertionError("Error in type compatibility check: " + msg + ".")
Example #3
def test_ravel_pytree(pytree):
    flat, unravel_fn = ravel_pytree(pytree)
    unravel = unravel_fn(flat)
    tree_flatten(
        tree_multimap(lambda x, y: assert_allclose(x, y), unravel, pytree))
    assert all(
        tree_flatten(
            tree_multimap(
                lambda x, y: jnp.result_type(x) == jnp.result_type(y), unravel,
                pytree))[0])
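For context, `ravel_pytree` (from `jax.flatten_util`) flattens a pytree into a single 1-D array and returns a function that restores the original structure; a minimal roundtrip sketch with an illustrative pytree:

import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

pytree = {"w": jnp.ones((2, 3)), "b": jnp.zeros(3)}
flat, unravel_fn = ravel_pytree(pytree)          # flat is a 1-D array of length 9
restored = unravel_fn(flat)
assert restored["w"].shape == (2, 3) and restored["b"].shape == (3,)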
Example #4
    def scan_fn(broadcast_in, init, *args):
        xs = jax.tree_multimap(transpose_to_front, in_axes, args)

        def body_fn(c, xs, init_mode=False):
            # inject constants
            xs = jax.tree_multimap(
                lambda ax, arg, x: (arg if ax is broadcast else x), in_axes,
                args, xs)
            broadcast_out, c, ys = fn(broadcast_in, c, *xs)

            if init_mode:
                ys = jax.tree_multimap(
                    lambda ax, y: (y if ax is broadcast else ()), out_axes, ys)
                return broadcast_out, ys
            else:
                ys = jax.tree_multimap(
                    lambda ax, y: (() if ax is broadcast else y), out_axes, ys)
                return c, ys

        broadcast_body = functools.partial(body_fn, init_mode=True)

        carry_pvals = jax.tree_map(
            lambda x: pe.PartialVal.unknown(
                jax.ShapedArray(jnp.shape(x), jnp.result_type(x))), init)
        scan_pvals = jax.tree_map(
            lambda x: pe.PartialVal.unknown(
                jax.ShapedArray(jnp.shape(x)[1:], jnp.result_type(x))), xs)
        input_pvals = (carry_pvals, scan_pvals)
        in_pvals, in_tree = jax.tree_flatten(input_pvals)
        f_flat, out_tree = jax.api_util.flatten_fun_nokwargs(
            lu.wrap_init(broadcast_body), in_tree)
        _, out_pvals, _ = pe.trace_to_jaxpr(f_flat, in_pvals)

        out_flat = []
        for pv, const in out_pvals:
            if pv is not None:
                raise ValueError(
                    'broadcasted variable has a data dependency on the scan body.'
                )
            out_flat.append(const)
        broadcast_in, constants_out = jax.tree_unflatten(out_tree(), out_flat)

        c, ys = lax.scan(body_fn, init, xs, length=length, reverse=reverse)
        ys = jax.tree_multimap(transpose_from_front, out_axes, ys)
        ys = jax.tree_multimap(
            lambda ax, const, y: (const if ax is broadcast else y), out_axes,
            constants_out, ys)
        return broadcast_in, c, ys
Example #5
def compute_weight_mat(input_size: core.DimSize, output_size: core.DimSize,
                       scale, translation, kernel: Callable, antialias: bool):
    dtype = jnp.result_type(scale, translation)
    inv_scale = 1. / scale
    # When downsampling the kernel should be scaled since we want to low pass
    # filter and interpolate, but when upsampling it should not be since we only
    # want to interpolate.
    kernel_scale = jnp.maximum(inv_scale, 1.) if antialias else 1.
    sample_f = ((jnp.arange(output_size, dtype=dtype) + 0.5) * inv_scale -
                translation * inv_scale - 0.5)
    x = (jnp.abs(sample_f[jnp.newaxis, :] -
                 jnp.arange(input_size, dtype=dtype)[:, jnp.newaxis]) /
         kernel_scale)
    weights = kernel(x)

    total_weight_sum = jnp.sum(weights, axis=0, keepdims=True)
    weights = jnp.where(
        jnp.abs(total_weight_sum) > 1000. * float(np.finfo(np.float32).eps),
        jnp.divide(weights,
                   jnp.where(total_weight_sum != 0, total_weight_sum, 1)), 0)
    # Zero out weights where the sample location is completely outside the input
    # range.
    # Note sample_f has already had the 0.5 removed, hence the weird range below.
    input_size_minus_0_5 = core.dimension_as_value(input_size) - 0.5
    return jnp.where(
        jnp.logical_and(sample_f >= -0.5,
                        sample_f <= input_size_minus_0_5)[jnp.newaxis, :],
        weights, 0)
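The routine above builds a matrix whose columns hold the kernel weights mapping input samples to each output sample. As a rough, self-contained illustration of the same idea (not the library routine; `triangle_kernel` and `weight_mat_1d` are hypothetical helpers), a 1-D downscale by two with a linear-interpolation kernel:

import jax.numpy as jnp

def triangle_kernel(x):
    # linear-interpolation kernel: 1 - |x|, clipped at zero
    return jnp.maximum(0., 1. - jnp.abs(x))

def weight_mat_1d(input_size, output_size, kernel, antialias=True):
    scale = output_size / input_size                     # < 1 means downsampling
    inv_scale = 1. / scale
    kernel_scale = max(inv_scale, 1.) if antialias else 1.
    sample_f = (jnp.arange(output_size) + 0.5) * inv_scale - 0.5
    x = jnp.abs(sample_f[None, :] - jnp.arange(input_size)[:, None]) / kernel_scale
    weights = kernel(x)
    return weights / jnp.sum(weights, axis=0, keepdims=True)   # normalize each column

signal = jnp.arange(8.0)
w = weight_mat_1d(8, 4, triangle_kernel)                 # shape (8, 4)
resized = signal @ w                                     # shape (4,), smoothed downscale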
Example #6
def test_log_prob_gradient(jax_dist, sp_dist, params):
    if jax_dist is dist.LKJCholesky:
        pytest.skip('we have separated tests for LKJCholesky distribution')
    rng = random.PRNGKey(0)

    value = jax_dist(*params).sample(rng)

    def fn(*args):
        return np.sum(jax_dist(*args).log_prob(value))

    eps = 1e-3
    for i in range(len(params)):
        if params[i] is None or np.result_type(
                params[i]) in (np.int32, np.int64):
            continue
        actual_grad = jax.grad(fn, i)(*params)
        args_lhs = [p if j != i else p - eps for j, p in enumerate(params)]
        args_rhs = [p if j != i else p + eps for j, p in enumerate(params)]
        fn_lhs = fn(*args_lhs)
        fn_rhs = fn(*args_rhs)
        # finite diff approximation
        expected_grad = (fn_rhs - fn_lhs) / (2. * eps)
        assert np.shape(actual_grad) == np.shape(params[i])
        if i == 0 and jax_dist is dist.Delta:
            # grad w.r.t. `value` of Delta distribution will be 0
            # but numerical value will give nan (= inf - inf)
            expected_grad = 0.
        assert_allclose(np.sum(actual_grad),
                        expected_grad,
                        rtol=0.01,
                        atol=1e-3)
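The test above compares `jax.grad` against a central finite-difference estimate, (f(θ + ε) − f(θ − ε)) / (2ε). A standalone sketch of the same check on a toy function (`f` here is illustrative):

import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(jnp.sin(x) ** 2)

x = jnp.array([0.3, 1.2])
eps = 1e-3
fd = (f(x.at[0].add(eps)) - f(x.at[0].add(-eps))) / (2. * eps)   # finite difference
ad = jax.grad(f)(x)[0]                                           # autodiff
assert jnp.allclose(fd, ad, rtol=1e-2, atol=1e-3)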
Example #7
def _resize(image, shape: core.Shape, method: Union[str, ResizeMethod],
            antialias: bool, precision):
    if len(shape) != image.ndim:
        msg = ('shape must have length equal to the number of dimensions of '
               f'image; {shape} vs {image.shape}')
        raise ValueError(msg)
    if isinstance(method, str):
        method = ResizeMethod.from_string(method)
    if method == ResizeMethod.NEAREST:
        return _resize_nearest(image, shape)
    assert isinstance(method, ResizeMethod)
    kernel = _kernels[method]

    if not jnp.issubdtype(image.dtype, jnp.inexact):
        image = lax.convert_element_type(image,
                                         jnp.result_type(image, jnp.float32))
    # Skip dimensions that have scale=1 and translation=0, this is only possible
    # since all of the current resize methods (kernels) are interpolating, so the
    # output = input under an identity warp.
    spatial_dims = tuple(
        i for i in range(len(shape))
        if not core.symbolic_equal_dim(image.shape[i], shape[i]))
    scale = [
        1.0 if core.symbolic_equal_dim(
            shape[d], 0) else core.dimension_as_value(shape[d]) /
        core.dimension_as_value(image.shape[d]) for d in spatial_dims
    ]
    return _scale_and_translate(image, shape, spatial_dims, scale,
                                [0.] * len(spatial_dims), kernel, antialias,
                                precision)
Example #8
def _multinomial(key, p, n, n_max, shape=()):
    if jnp.shape(n) != jnp.shape(p)[:-1]:
        broadcast_shape = lax.broadcast_shapes(jnp.shape(n), jnp.shape(p)[:-1])
        n = jnp.broadcast_to(n, broadcast_shape)
        p = jnp.broadcast_to(p, broadcast_shape + jnp.shape(p)[-1:])
    shape = shape or p.shape[:-1]
    if n_max == 0:
        return jnp.zeros(shape + p.shape[-1:], dtype=jnp.result_type(int))
    # get indices from categorical distribution then gather the result
    indices = categorical(key, p, (n_max,) + shape)
    # mask out values when counts is heterogeneous
    if jnp.ndim(n) > 0:
        mask = promote_shapes(
            jnp.arange(n_max) < jnp.expand_dims(n, -1), shape=shape + (n_max,)
        )[0]
        mask = jnp.moveaxis(mask, -1, 0).astype(indices.dtype)
        excess = jnp.concatenate(
            [
                jnp.expand_dims(n_max - n, -1),
                jnp.zeros(jnp.shape(n) + (p.shape[-1] - 1,)),
            ],
            -1,
        )
    else:
        mask = 1
        excess = 0
    # NB: we transpose to move batch shape to the front
    indices_2D = (jnp.reshape(indices * mask, (n_max, -1))).T
    samples_2D = vmap(_scatter_add_one, (0, 0, 0))(
        jnp.zeros((indices_2D.shape[0], p.shape[-1]), dtype=indices.dtype),
        jnp.expand_dims(indices_2D, axis=-1),
        jnp.ones(indices_2D.shape, dtype=indices.dtype),
    )
    return jnp.reshape(samples_2D, shape + p.shape[-1:]) - excess
Example #9
def test_log_prob_gradient(jax_dist, sp_dist, params):
    if jax_dist is dist.LKJCholesky:
        pytest.skip('we have separated tests for LKJCholesky distribution')
    rng = random.PRNGKey(0)

    def fn(args, value):
        return np.sum(jax_dist(*args).log_prob(value))

    value = jax_dist(*params).sample(rng)
    actual_grad = jax.grad(fn)(params, value)
    assert len(actual_grad) == len(params)

    eps = 1e-3
    for i in range(len(params)):
        if np.result_type(params[i]) in (np.int32, np.int64):
            continue
        args_lhs = [p if j != i else p - eps for j, p in enumerate(params)]
        args_rhs = [p if j != i else p + eps for j, p in enumerate(params)]
        fn_lhs = fn(args_lhs, value)
        fn_rhs = fn(args_rhs, value)
        # finite diff approximation
        expected_grad = (fn_rhs - fn_lhs) / (2. * eps)
        assert np.shape(actual_grad[i]) == np.shape(params[i])
        assert_allclose(np.sum(actual_grad[i]),
                        expected_grad,
                        rtol=0.01,
                        atol=1e-3)
Example #10
    def init_fn(z_info, rng_key, step_size=1.0, inverse_mass_matrix=None, mass_matrix_size=None):
        """
        :param IntegratorState z_info: The initial integrator state.
        :param jax.random.PRNGKey rng_key: Random key to be used as the source of randomness.
        :param float step_size: Initial step size.
        :param inverse_mass_matrix: Inverse of the initial mass matrix. If ``None``,
            the inverse mass matrix will be an identity matrix whose size is
            determined by the argument `mass_matrix_size`.
        :param int mass_matrix_size: Size of the mass matrix.
        :return: initial state of the adapt scheme.
        """
        rng_key, rng_key_ss = random.split(rng_key)
        inverse_mass_matrix, mass_matrix_sqrt, mass_matrix_sqrt_inv = _initialize_mass_matrix(
            z_info[0], inverse_mass_matrix, dense_mass
        )

        if adapt_step_size:
            step_size = find_reasonable_step_size(step_size, inverse_mass_matrix, z_info, rng_key_ss)
        ss_state = ss_init(jnp.log(10 * step_size))

        if isinstance(inverse_mass_matrix, dict):
            size = {k: v.shape for k, v in inverse_mass_matrix.items()}
        else:
            size = inverse_mass_matrix.shape[-1]
        mm_state = mm_init(size)

        window_idx = jnp.array(0, dtype=jnp.result_type(int))
        return HMCAdaptState(step_size, inverse_mass_matrix, mass_matrix_sqrt, mass_matrix_sqrt_inv,
                             ss_state, mm_state, window_idx, rng_key)
Example #11
 def __init__(self, name, shape):
     prior_base = UniformBase(shape, jnp.result_type(float))
     super().__init__(name,
                      shape,
                      parents=[],
                      tracked=True,
                      prior_base=prior_base)
Example #12
    def init(self, rng_key, *args, **kwargs):
        """
        Gets the initial SVI state.

        :param jax.random.PRNGKey rng_key: random number generator seed.
        :param args: arguments to the model / guide (these can possibly vary during
            the course of fitting).
        :param kwargs: keyword arguments to the model / guide (these can possibly vary
            during the course of fitting).
        :return: the initial :data:`SVIState`
        """
        rng_key, model_seed, guide_seed = random.split(rng_key, 3)
        model_init = seed(self.model, model_seed)
        guide_init = seed(self.guide, guide_seed)
        guide_trace = trace(guide_init).get_trace(*args, **kwargs,
                                                  **self.static_kwargs)
        model_trace = trace(replay(model_init, guide_trace)).get_trace(
            *args, **kwargs, **self.static_kwargs)
        params = {}
        inv_transforms = {}
        # NB: params in model_trace will be overwritten by params in guide_trace
        for site in list(model_trace.values()) + list(guide_trace.values()):
            if site['type'] == 'param':
                constraint = site['kwargs'].pop('constraint', constraints.real)
                transform = biject_to(constraint)
                inv_transforms[site['name']] = transform
                params[site['name']] = transform.inv(site['value'])

        self.constrain_fn = partial(transform_fn, inv_transforms)
        # we convert weak types like float to float32/float64
        # to avoid recompiling body_fn in svi.run
        params = tree_map(
            lambda x: lax.convert_element_type(x, jnp.result_type(x)), params)
        return SVIState(self.optim.init(params), rng_key)
Example #13
 def enumerate_support(self, expand=True):
     n = self.event_shape[-1]
     values = jnp.identity(n, dtype=jnp.result_type(self.dtype))
     values = values.reshape((n, ) + (1, ) * len(self.batch_shape) + (n, ))
     if expand:
         values = jnp.broadcast_to(values, (n, ) + self.batch_shape + (n, ))
     return values
Example #14
 def sample(self, key, sample_shape=()):
     assert is_prng_key(key)
     probs = self.probs
     dtype = jnp.result_type(probs)
     shape = sample_shape + self.batch_shape
     u = random.uniform(key, shape, dtype)
     return jnp.floor(jnp.log1p(-u) / jnp.log1p(-probs))
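This sampler inverts the geometric CDF: with u ~ Uniform(0, 1), floor(log(1 − u) / log(1 − p)) is the number of failures before the first success. A standalone sketch under that assumption:

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
p = 0.3
u = jax.random.uniform(key, (10000,))
samples = jnp.floor(jnp.log1p(-u) / jnp.log1p(-p))
print(samples.mean())   # should be close to (1 - p) / p ≈ 2.33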
Example #15
def _kth_arnoldi_iteration(k, A, M, V, H):
  """
  Performs a single (the k'th) step of the Arnoldi process. Thus,
  adds a new orthonormalized Krylov vector A(M(V[:, k])) to V[:, k+1],
  and that vectors overlaps with the existing Krylov vectors to
  H[k, :]. The tolerance 'tol' sets the threshold at which an invariant
  subspace is declared to have been found, in which case in which case the new
  vector is taken to be the zero vector.
  """
  dtype = jnp.result_type(*tree_leaves(V))
  eps = jnp.finfo(dtype).eps

  v = tree_map(lambda x: x[..., k], V)  # Gets V[:, k]
  v = M(A(v))
  _, v_norm_0 = _safe_normalize(v)
  v, h = _iterative_classical_gram_schmidt(V, v, v_norm_0, max_iterations=2)

  tol = eps * v_norm_0
  unit_v, v_norm_1 = _safe_normalize(v, thresh=tol)
  V = tree_map(lambda X, y: X.at[..., k + 1].set(y), V, unit_v)

  h = h.at[k + 1].set(v_norm_1.astype(dtype))
  H = H.at[k, :].set(h)
  breakdown = v_norm_1 == 0.
  return V, H, breakdown
Example #16
def _cg_solve(A, b, x0=None, *, maxiter, tol=1e-5, atol=0.0, M=_identity):

  # tolerance handling uses the "non-legacy" behavior of scipy.sparse.linalg.cg
  bs = _vdot_real_tree(b, b)
  atol2 = jnp.maximum(jnp.square(tol) * bs, jnp.square(atol))

  # https://en.wikipedia.org/wiki/Conjugate_gradient_method#The_preconditioned_conjugate_gradient_method

  def cond_fun(value):
    _, r, gamma, _, k = value
    rs = gamma.real if M is _identity else _vdot_real_tree(r, r)
    return (rs > atol2) & (k < maxiter)

  def body_fun(value):
    x, r, gamma, p, k = value
    Ap = A(p)
    alpha = gamma / _vdot_real_tree(p, Ap).astype(dtype)
    x_ = _add(x, _mul(alpha, p))
    r_ = _sub(r, _mul(alpha, Ap))
    z_ = M(r_)
    gamma_ = _vdot_real_tree(r_, z_).astype(dtype)
    beta_ = gamma_ / gamma
    p_ = _add(z_, _mul(beta_, p))
    return x_, r_, gamma_, p_, k + 1

  r0 = _sub(b, A(x0))
  p0 = z0 = M(r0)
  dtype = jnp.result_type(*tree_leaves(p0))
  gamma0 = _vdot_real_tree(r0, z0).astype(dtype)
  initial_value = (x0, r0, gamma0, p0, 0)

  x_final, *_ = lax.while_loop(cond_fun, body_fun, initial_value)

  return x_final
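`_cg_solve` is an internal helper; the public entry point `jax.scipy.sparse.linalg.cg` accepts either a matrix or a matvec callable. A minimal usage sketch on a small symmetric positive-definite system:

import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

A = jnp.array([[4., 1.], [1., 3.]])
b = jnp.array([1., 2.])
x, info = cg(lambda v: A @ v, b, tol=1e-6)
assert jnp.allclose(A @ x, b, atol=1e-4)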
Example #17
 def sample(self, key, sample_shape=()):
     assert is_prng_key(key)
     logits = self.logits
     dtype = jnp.result_type(logits)
     shape = sample_shape + self.batch_shape
     u = random.uniform(key, shape, dtype)
     return jnp.floor(jnp.log1p(-u) / -softplus(logits))
Example #18
 def sample(self, key, sample_shape=()):
     assert is_prng_key(key)
     dtype = jnp.result_type(float)
     finfo = jnp.finfo(dtype)
     minval = finfo.tiny
     u = random.uniform(key, shape=sample_shape + self.batch_shape, minval=minval)
     return self.base_dist.icdf(u * self._cdf_at_high)
Example #19
def _jacobian_cplx(forward_fn: Callable, params: PyTree, samples: Array,
                   _build_fn: Callable) -> PyTree:
    """Calculates one Jacobian entry.
    Assumes the function is R→C, backpropagates 1 and -1j

    Args:
        forward_fn: the log wavefunction ln Ψ
        params : a pytree of parameters p
        σ : a single sample (vector)

    Returns:
        The Jacobian matrix ∂/∂pₖ ln Ψ(σⱼ) as a PyTree
    """
    y, vjp_fun = jax.vjp(single_sample(forward_fn), params, samples)
    gr, _ = vjp_fun(np.array(1.0, dtype=jnp.result_type(y)))
    gi, _ = vjp_fun(np.array(-1.0j, dtype=jnp.result_type(y)))
    return _build_fn(gr, gi)
Example #20
 def _inverse(self, y):
     size = self.permutation.size
     permutation_inv = ops.index_update(
         jnp.zeros(size, dtype=jnp.result_type(int)),
         self.permutation,
         jnp.arange(size),
     )
     return y[..., permutation_inv]
Example #21
 def _inverse(self, y):
     size = self.permutation.size
     permutation_inv = (
         jnp.zeros(size, dtype=jnp.result_type(int))
         .at[self.permutation]
         .set(jnp.arange(size))
     )
     return y[..., permutation_inv]
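A quick check of the scatter-based inverse used above, with an illustrative permutation:

import jax.numpy as jnp

perm = jnp.array([2, 0, 1])
perm_inv = jnp.zeros(3, dtype=jnp.result_type(int)).at[perm].set(jnp.arange(3))
x = jnp.array([10., 20., 30.])
assert jnp.array_equal(x[perm][perm_inv], x)   # perm_inv undoes perm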
Example #22
def _compute_stats(x, axes):
  # promote x to at least float32, this avoids half precision computation
  # but preserves double or complex floating points
  x = jnp.asarray(x, jnp.promote_types(jnp.float32, jnp.result_type(x)))
  mean = jnp.mean(x, axes)
  mean2 = jnp.mean(jnp.square(x), axes)
  # mean2 - _abs_sq(mean) is not guaranteed to be non-negative due
  # to floating point round-off errors.
  var = jnp.maximum(0., mean2 - jnp.square(mean))
  return mean, var
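The promotion above keeps half-precision inputs from being averaged in float16 while leaving higher-precision or complex inputs untouched; a short sketch of the resulting dtypes:

import jax.numpy as jnp

x16 = jnp.ones((4,), dtype=jnp.float16)
xc = jnp.ones((4,), dtype=jnp.complex64)
print(jnp.promote_types(jnp.float32, x16.dtype))   # float32: half is upcast
print(jnp.promote_types(jnp.float32, xc.dtype))    # complex64: complex is preserved
mean, var = _compute_stats(x16, axes=0)
print(mean.dtype)                                   # float32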
Example #23
    def value_and_grad_f(*args, **kwargs):
        f = lu.wrap_init(fun, kwargs)
        f_partial, dyn_args = argnums_partial(f, argnums, args)
        ans, vjp_py = vjp(f_partial, *dyn_args)

        g = vjp_py(
            jnp.ones((), jnp.result_type(ans))
            if initial_grad is None else initial_grad)
        g = g[0] if isinstance(argnums, int) else g
        return (ans, g)
Example #24
def _gmres_qr(A, b, x0, unit_residual, residual_norm, inner_tol, restart, M):
    """
  Implements a single restart of GMRES. The restart-dimensional Krylov subspace
  K(A, x0) = span(A(x0), A@x0, A@A@x0, ..., A^restart @ x0) is built, and the
  projection of the true solution into this subspace is returned.

  This implementation builds the QR factorization during the Arnoldi process.
  """
    # https://www-users.cs.umn.edu/~saad/Calais/PREC.pdf
    #  residual = _sub(b, A(x0))
    #  unit_residual, beta = _safe_normalize(residual)

    V = tree_map(
        lambda x: jnp.pad(x[..., None], ((0, 0), ) * x.ndim +
                          ((0, restart), )),
        unit_residual,
    )
    dtype = jnp.result_type(*tree_leaves(b))
    # Use eye to avoid constructing a singular matrix in case of early
    # termination.
    R = jnp.eye(restart, restart + 1, dtype=dtype)
    b_norm = _norm_tree(b)

    givens = jnp.zeros((restart, 2), dtype=dtype)
    beta_vec = jnp.zeros((restart + 1), dtype=dtype)
    beta_vec = beta_vec.at[0].set(residual_norm)

    def loop_cond(carry):
        k, err, _, _, _, _ = carry
        return jnp.logical_and(k < restart, err > inner_tol)

    def arnoldi_qr_step(carry):
        k, _, V, R, beta_vec, givens = carry
        V, H, _ = _kth_arnoldi_iteration(k, A, M, V, R, inner_tol)
        R_row, givens = _apply_givens_rotations(H[k, :], givens, k)
        R = R.at[k, :].set(R_row[:])
        cs, sn = givens[k, :] * beta_vec[k]
        beta_vec = beta_vec.at[k].set(cs)
        beta_vec = beta_vec.at[k + 1].set(sn)
        err = jnp.abs(sn) / b_norm
        return k + 1, err, V, R, beta_vec, givens

    carry = (0, residual_norm, V, R, beta_vec, givens)
    carry = lax.while_loop(loop_cond, arnoldi_qr_step, carry)
    k, residual_norm, V, R, beta_vec, _ = carry
    del k  # Until we figure out how to pass this to the user.

    y = jsp.linalg.solve_triangular(R[:, :-1].T, beta_vec[:-1])
    Vy = tree_map(lambda X: _dot(X[..., :-1], y), V)
    dx = M(Vy)

    x = _add(x0, dx)
    residual = _sub(b, A(x))
    unit_residual, residual_norm = _safe_normalize(residual)
    return x, unit_residual, residual_norm
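As with the CG example earlier, this is an internal helper behind `jax.scipy.sparse.linalg.gmres`, which also handles non-symmetric systems; a minimal usage sketch:

import jax.numpy as jnp
from jax.scipy.sparse.linalg import gmres

A = jnp.array([[3., 1.], [0., 2.]])   # non-symmetric
b = jnp.array([9., 4.])
x, info = gmres(lambda v: A @ v, b, tol=1e-6)
assert jnp.allclose(A @ x, b, atol=1e-4)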
Example #25
 def sample(self, key, sample_shape=()):
     assert is_prng_key(key)
     dtype = jnp.result_type(float)
     finfo = jnp.finfo(dtype)
     minval = finfo.tiny
     u = random.uniform(key, shape=sample_shape + self.batch_shape, minval=minval)
     loc = self.base_dist.loc
     sign = jnp.where(loc >= self.low, 1.0, -1.0)
     return (1 - sign) * loc + sign * self.base_dist.icdf(
         (1 - u) * self._tail_prob_at_low + u * self._tail_prob_at_high
     )
Example #26
 def init_fn(prox_center=0.0):
     """
     :param float prox_center: A parameter introduced in reference [1] which
         pulls the primal sequence towards it. Defaults to 0.
     :return: initial state for the scheme.
     """
     x_t = jnp.zeros(())
     x_avg = jnp.zeros(())  # average of primal sequence
     g_avg = jnp.zeros(())  # average of dual sequence
     t = jnp.array(0, dtype=jnp.result_type(int))
     return x_t, x_avg, g_avg, t, prox_center
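`jnp.result_type(int)` resolves the weak Python `int` type to JAX's default integer dtype (int32 by default, int64 when x64 is enabled), which gives counters like `t` above a stable concrete dtype; a quick sketch:

import jax.numpy as jnp

print(jnp.result_type(int))      # int32 by default, int64 with jax_enable_x64
print(jnp.result_type(float))    # float32 by default, float64 with jax_enable_x64
t = jnp.array(0, dtype=jnp.result_type(int))
print(t.dtype)                   # matches the default integer dtype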
Example #27
 def _sample_n(self, key: PRNGKey, n: int) -> Array:
     """See `Distribution._sample_n`."""
     out_shape = (n, ) + self.batch_shape
     dtype = jnp.result_type(self._loc, self._scale)
     uniform = jax.random.uniform(key,
                                  shape=out_shape,
                                  dtype=dtype,
                                  minval=jnp.finfo(dtype).tiny,
                                  maxval=1.)
     rnd = jnp.log(uniform) - jnp.log1p(-uniform)
     return self._scale * rnd + self._loc
Example #28
def build_tree(verlet_update, kinetic_fn, verlet_state, inverse_mass_matrix, step_size, rng_key,
               max_delta_energy=1000., max_tree_depth=10):
    """
    Builds a binary tree from the `verlet_state`. This is used in NUTS sampler.

    **References:**

    1. *The No-U-Turn Sampler: Adaptively Setting Path Lengths in Hamiltonian Monte Carlo*,
       Matthew D. Hoffman, Andrew Gelman
    2. *A Conceptual Introduction to Hamiltonian Monte Carlo*,
       Michael Betancourt

    :param verlet_update: A callable to get a new integrator state given a current
        integrator state.
    :param kinetic_fn: A callable to compute kinetic energy.
    :param verlet_state: Initial integrator state.
    :param inverse_mass_matrix: Inverse of the mass matrix.
    :param float step_size: Step size for the current trajectory.
    :param jax.random.PRNGKey rng_key: random key to be used as the source of
        randomness.
    :param float max_delta_energy: A threshold to decide if the new state diverges
        (based on the energy difference) too much from the initial integrator state.
    :return: information of the tree.
    :rtype: :data:`TreeInfo`
    """
    z, r, potential_energy, z_grad = verlet_state
    energy_current = potential_energy + kinetic_fn(inverse_mass_matrix, r)
    latent_size = jnp.size(ravel_pytree(r)[0])
    r_ckpts = jnp.zeros((max_tree_depth, latent_size))
    r_sum_ckpts = jnp.zeros((max_tree_depth, latent_size))

    tree = TreeInfo(z, r, z_grad, z, r, z_grad, z, potential_energy, z_grad, energy_current,
                    depth=0, weight=jnp.zeros(()), r_sum=r, turning=jnp.array(False),
                    diverging=jnp.array(False),
                    sum_accept_probs=jnp.zeros(()),
                    num_proposals=jnp.array(0, dtype=jnp.result_type(int)))

    def _cond_fn(state):
        tree, _ = state
        return (tree.depth < max_tree_depth) & ~tree.turning & ~tree.diverging

    def _body_fn(state):
        tree, key = state
        key, direction_key, doubling_key = random.split(key, 3)
        going_right = random.bernoulli(direction_key)
        tree = _double_tree(tree, verlet_update, kinetic_fn, inverse_mass_matrix, step_size,
                            going_right, doubling_key, energy_current, max_delta_energy,
                            r_ckpts, r_sum_ckpts)
        return tree, key

    state = (tree, rng_key)
    tree, _ = while_loop(_cond_fn, _body_fn, state)
    return tree
Example #29
def cartesian_product(*arrays):
    """
    IN: any number of np arrays of same length
    OUT: cartesian product of the arrays
    """
    la = len(arrays)
    dtype = np.result_type(*arrays)
    arr = np.empty([len(a) for a in arrays] + [la], dtype=dtype)
    for i, a in enumerate(np.ix_(*arrays)):
        #         arr[...,i] = a
        arr = index_update(arr, index[..., i], a)
    return arr.reshape(-1, la)
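A small usage sketch of the helper above (the inputs are ordinary NumPy arrays; note that `jax.ops.index_update` has since been replaced by the `.at[...].set(...)` syntax in newer JAX releases):

import numpy as np

a = np.array([1, 2])
b = np.array([10, 20, 30])
print(cartesian_product(a, b))
# [[ 1 10]
#  [ 1 20]
#  [ 1 30]
#  [ 2 10]
#  [ 2 20]
#  [ 2 30]]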
Example #30
    def init(self, rng_key, *args, **kwargs):
        """
        Gets the initial SVI state.

        :param jax.random.PRNGKey rng_key: random number generator seed.
        :param args: arguments to the model / guide (these can possibly vary during
            the course of fitting).
        :param kwargs: keyword arguments to the model / guide (these can possibly vary
            during the course of fitting).
        :return: the initial :data:`SVIState`
        """
        rng_key, model_seed, guide_seed = random.split(rng_key, 3)
        model_init = seed(self.model, model_seed)
        guide_init = seed(self.guide, guide_seed)
        guide_trace = trace(guide_init).get_trace(*args, **kwargs, **self.static_kwargs)
        model_trace = trace(replay(model_init, guide_trace)).get_trace(
            *args, **kwargs, **self.static_kwargs
        )
        params = {}
        inv_transforms = {}
        mutable_state = {}
        # NB: params in model_trace will be overwritten by params in guide_trace
        for site in list(model_trace.values()) + list(guide_trace.values()):
            if site["type"] == "param":
                constraint = site["kwargs"].pop("constraint", constraints.real)
                with helpful_support_errors(site):
                    transform = biject_to(constraint)
                inv_transforms[site["name"]] = transform
                params[site["name"]] = transform.inv(site["value"])
            elif site["type"] == "mutable":
                mutable_state[site["name"]] = site["value"]
            elif (
                site["type"] == "sample"
                and (not site["is_observed"])
                and site["fn"].support.is_discrete
                and not self.loss.can_infer_discrete
            ):
                s_name = type(self.loss).__name__
                warnings.warn(
                    f"Currently, SVI with {s_name} loss does not support models with discrete latent variables"
                )

        if not mutable_state:
            mutable_state = None
        self.constrain_fn = partial(transform_fn, inv_transforms)
        # we convert weak types like float to float32/float64
        # to avoid recompiling body_fn in svi.run
        params, mutable_state = tree_map(
            lambda x: lax.convert_element_type(x, jnp.result_type(x)),
            (params, mutable_state),
        )
        return SVIState(self.optim.init(params), mutable_state, rng_key)