def type_assert(
        inputs: Union[ArrayLike, List[ArrayLike]],
        expected_types: Union[Type[Scalar], List[Type[Scalar]]]):
    """Verifies that each input belongs to the dtype family of its expected type.

    Args:
      inputs: a single input or a list of inputs.
      expected_types: a list holding one expected type per input; a single type
        may be given instead, in which case it is applied to every input.

    Raises:
      ValueError: if `inputs` and `expected_types` differ in length, if an
        expected type is neither a floating nor an integer type, or if the
        result type of an input falls outside the family of its expected type.
    """
    input_list = inputs if isinstance(inputs, list) else [inputs]
    if isinstance(expected_types, list):
        type_list = expected_types
    else:
        type_list = [expected_types] * len(input_list)

    if len(input_list) != len(type_list):
        raise ValueError("Length of inputs and expected_types must match.")

    # Only these two dtype families are recognised; they are probed in order.
    supported_families = (jnp.floating, jnp.integer)
    for value, wanted in zip(input_list, type_list):
        family = next(
            (fam for fam in supported_families if jnp.issubdtype(wanted, fam)),
            None)
        if family is None:
            raise ValueError(
                "Error in type compatibility check, unsupported dtype"
                " {}".format(wanted))

        found = jnp.result_type(value)
        if not jnp.issubdtype(found, family):
            raise ValueError("Error in type compatibility check, found {} but "
                             "expected {}.".format(found, wanted))
def assert_type(
        inputs: Union[Scalar, Union[Array, Sequence[Array]]],
        expected_types: Union[Type[Scalar], Sequence[Type[Scalar]]]):
    """Asserts that every entry of `inputs` has the matching `expected_types` kind.

    Examples of valid calls:

    ```
      assert_type(7, int)
      assert_type(7.1, float)
      assert_type(False, bool)
      assert_type([7, 8], int)
      assert_type([7, 7.1], [int, float])
      assert_type(np.array(7), int)
      assert_type(np.array(7.1), float)
      assert_type(jnp.array(7), int)
      assert_type([jnp.array([7, 8]), np.array(7.1)], [int, float])
    ```

    Args:
      inputs: an array, a scalar, or a list/tuple of arrays or scalars.
      expected_types: a list/tuple with one expected type per input; a single
        type may be given instead and is then applied to every input.

    Raises:
      AssertionError: if `inputs` and `expected_types` differ in length; if an
        expected type is not a floating, integer or boolean type; if any input
        does not belong to the family of its expected type (all mismatches are
        reported together in one message).
    """
    if not isinstance(inputs, (list, tuple)):
        inputs = [inputs]
    if not isinstance(expected_types, (list, tuple)):
        expected_types = [expected_types] * len(inputs)

    if len(inputs) != len(expected_types):
        raise AssertionError(
            f"Length of `inputs` and `expected_types` must match, "
            f"got {len(inputs)} != {len(expected_types)}.")

    # Families are probed in this order; the first match wins.
    families = (jnp.floating, jnp.integer, jnp.bool_)
    mismatches = []
    for position, (value, wanted) in enumerate(zip(inputs, expected_types)):
        family = next(
            (fam for fam in families if jnp.issubdtype(wanted, fam)), None)
        if family is None:
            # An unsupported expected type aborts immediately, even when
            # mismatches have already been collected.
            raise AssertionError(
                f"Error in type compatibility check, unsupported dtype '{wanted}'."
            )

        found = jnp.result_type(value)
        if not jnp.issubdtype(found, family):
            mismatches.append(
                f"input {position} has type {found} but expected {wanted}")

    if mismatches:
        raise AssertionError(
            "Error in type compatibility check: " + "; ".join(mismatches) + ".")
def test_ravel_pytree(pytree):
    """Round-trip check: unravelling a ravelled pytree restores values and dtypes."""
    flat_vector, restore = ravel_pytree(pytree)
    rebuilt = restore(flat_vector)

    def _check_close(x, y):
        # Raises on mismatch; the returned value itself is not used.
        return assert_allclose(x, y)

    def _same_result_type(x, y):
        return jnp.result_type(x) == jnp.result_type(y)

    tree_flatten(tree_multimap(_check_close, rebuilt, pytree))

    dtype_matches, _ = tree_flatten(
        tree_multimap(_same_result_type, rebuilt, pytree))
    assert all(dtype_matches)
def scan_fn(broadcast_in, init, *args):
    """Runs `fn` under `lax.scan`, separating loop-invariant ("broadcast") values.

    `fn`, `in_axes`, `out_axes`, `broadcast`, `length`, `reverse`,
    `transpose_to_front` and `transpose_from_front` are free variables taken
    from the enclosing scope, which is not shown here.

    Args:
      broadcast_in: value handed to `fn` as its first argument on every step.
      init: initial carry passed to `lax.scan`.
      *args: remaining inputs; entries whose `in_axes` marker is `broadcast`
        are passed to `fn` unchanged instead of being scanned over.

    Returns:
      A tuple `(broadcast_in, c, ys)`: the broadcast output obtained from the
      init-mode trace, the final carry, and the outputs in which entries marked
      `broadcast` in `out_axes` are filled with the traced constants.

    Raises:
      ValueError: if an output of the init-mode trace is not a known constant.
    """
    xs = jax.tree_multimap(transpose_to_front, in_axes, args)

    def body_fn(c, xs, init_mode=False):
        # inject constants
        xs = jax.tree_multimap(
            lambda ax, arg, x: (arg if ax is broadcast else x), in_axes, args, xs)
        broadcast_out, c, ys = fn(broadcast_in, c, *xs)

        if init_mode:
            # Keep only the outputs marked as broadcast; the rest become ().
            ys = jax.tree_multimap(
                lambda ax, y: (y if ax is broadcast else ()), out_axes, ys)
            return broadcast_out, ys
        else:
            # Regular scan step: drop the broadcast outputs, keep the rest.
            ys = jax.tree_multimap(
                lambda ax, y: (() if ax is broadcast else y), out_axes, ys)
            return c, ys

    broadcast_body = functools.partial(body_fn, init_mode=True)

    # Trace the body once with every carry/scan input marked as unknown.
    # The scan inputs drop their leading axis (shape[1:]).
    carry_pvals = jax.tree_map(
        lambda x: pe.PartialVal.unknown(
            jax.ShapedArray(jnp.shape(x), jnp.result_type(x))),
        init)
    scan_pvals = jax.tree_map(
        lambda x: pe.PartialVal.unknown(
            jax.ShapedArray(jnp.shape(x)[1:], jnp.result_type(x))),
        xs)
    input_pvals = (carry_pvals, scan_pvals)
    in_pvals, in_tree = jax.tree_flatten(input_pvals)
    f_flat, out_tree = jax.api_util.flatten_fun_nokwargs(
        lu.wrap_init(broadcast_body), in_tree)
    _, out_pvals, _ = pe.trace_to_jaxpr(f_flat, in_pvals)

    out_flat = []
    for pv, const in out_pvals:
        # A non-None `pv` is rejected: the output is not a plain constant.
        if pv is not None:
            raise ValueError(
                'broadcasted variable has a data dependency on the scan body.'
            )
        out_flat.append(const)
    broadcast_in, constants_out = jax.tree_unflatten(out_tree(), out_flat)

    c, ys = lax.scan(body_fn, init, xs, length=length, reverse=reverse)
    ys = jax.tree_multimap(transpose_from_front, out_axes, ys)
    # Put the traced constants back into the slots marked as broadcast.
    ys = jax.tree_multimap(
        lambda ax, const, y: (const if ax is broadcast else y), out_axes,
        constants_out, ys)
    return broadcast_in, c, ys
def compute_weight_mat(input_size: core.DimSize,
                       output_size: core.DimSize,
                       scale,
                       translation,
                       kernel: Callable,
                       antialias: bool):
    """Builds the resampling weight matrix for one dimension.

    Args:
      input_size: number of input samples along the dimension.
      output_size: number of output samples along the dimension.
      scale: scale factor; its reciprocal maps output positions to input
        coordinates.
      translation: offset subtracted (after scaling by `1 / scale`) from the
        sample positions.
      kernel: callable applied elementwise to the scaled absolute distances
        between sample positions and input positions.
      antialias: if true, the kernel is widened by `max(1 / scale, 1)`.

    Returns:
      An array with `input_size` rows and `output_size` columns. Each column is
      divided by its sum; columns whose sum is near zero, or whose sample
      position lies outside the input range, are set to zero.
    """
    dtype = jnp.result_type(scale, translation)
    inv_scale = 1. / scale
    # When downsampling the kernel should be scaled since we want to low pass
    # filter and interpolate, but when upsampling it should not be since we only
    # want to interpolate.
    kernel_scale = jnp.maximum(inv_scale, 1.) if antialias else 1.
    # Position of each output sample in input coordinates, with 0.5 removed.
    sample_f = ((jnp.arange(output_size, dtype=dtype) + 0.5) * inv_scale -
                translation * inv_scale - 0.5)
    # Distance from every input index (rows) to every sample position (columns).
    x = (
        jnp.abs(sample_f[jnp.newaxis, :] -
                jnp.arange(input_size, dtype=dtype)[:, jnp.newaxis]) /
        kernel_scale)
    weights = kernel(x)

    total_weight_sum = jnp.sum(weights, axis=0, keepdims=True)
    # Normalise each column. The inner `where` swaps zero sums for 1 so the
    # division itself never divides by zero.
    weights = jnp.where(
        jnp.abs(total_weight_sum) > 1000. * float(np.finfo(np.float32).eps),
        jnp.divide(weights, jnp.where(total_weight_sum != 0, total_weight_sum, 1)),
        0)
    # Zero out weights where the sample location is completely outside the input
    # range.
    # Note sample_f has already had the 0.5 removed, hence the weird range below.
    input_size_minus_0_5 = core.dimension_as_value(input_size) - 0.5
    return jnp.where(
        jnp.logical_and(sample_f >= -0.5,
                        sample_f <= input_size_minus_0_5)[jnp.newaxis, :], weights, 0)
def test_log_prob_gradient(jax_dist, sp_dist, params):
    """Checks autodiff gradients of the summed log-prob against central differences.

    Each parameter that is not `None` and not of an integer result type is
    perturbed by +/- eps, and the summed gradient from `jax.grad` is compared
    with the finite-difference slope.
    """
    if jax_dist is dist.LKJCholesky:
        pytest.skip('we have separated tests for LKJCholesky distribution')

    rng = random.PRNGKey(0)
    value = jax_dist(*params).sample(rng)

    def fn(*args):
        return np.sum(jax_dist(*args).log_prob(value))

    eps = 1e-3
    integer_types = (np.int32, np.int64)
    for i, param in enumerate(params):
        if param is None or np.result_type(param) in integer_types:
            continue

        actual_grad = jax.grad(fn, i)(*params)

        # Shift only the i-th parameter; all others stay as they are.
        args_lhs = [p - eps if j == i else p for j, p in enumerate(params)]
        args_rhs = [p + eps if j == i else p for j, p in enumerate(params)]
        fn_lhs = fn(*args_lhs)
        fn_rhs = fn(*args_rhs)
        # finite diff approximation
        expected_grad = (fn_rhs - fn_lhs) / (2. * eps)

        assert np.shape(actual_grad) == np.shape(param)
        if i == 0 and jax_dist is dist.Delta:
            # grad w.r.t. `value` of Delta distribution will be 0
            # but numerical value will give nan (= inf - inf)
            expected_grad = 0.
        assert_allclose(np.sum(actual_grad), expected_grad, rtol=0.01, atol=1e-3)
def _resize(image, shape: core.Shape, method: Union[str, ResizeMethod],
            antialias: bool, precision):
    """Resizes `image` to `shape` using the given interpolation `method`.

    A string `method` is converted with `ResizeMethod.from_string`. Nearest
    resizing is delegated to `_resize_nearest`; every other method goes through
    `_scale_and_translate` with zero translation.

    Raises:
      ValueError: if `shape` does not have one entry per dimension of `image`.
    """
    if len(shape) != image.ndim:
        raise ValueError(
            'shape must have length equal to the number of dimensions of x; '
            f' {shape} vs {image.shape}')

    if isinstance(method, str):
        method = ResizeMethod.from_string(method)
    if method == ResizeMethod.NEAREST:
        return _resize_nearest(image, shape)
    assert isinstance(method, ResizeMethod)
    kernel = _kernels[method]

    # Non-inexact images are promoted against float32 before interpolating.
    if not jnp.issubdtype(image.dtype, jnp.inexact):
        promoted = jnp.result_type(image, jnp.float32)
        image = lax.convert_element_type(image, promoted)

    # Skip dimensions that have scale=1 and translation=0, this is only possible
    # since all of the current resize methods (kernels) are interpolating, so the
    # output = input under an identity warp.
    spatial_dims = tuple(
        axis for axis, target in enumerate(shape)
        if not core.symbolic_equal_dim(image.shape[axis], target))

    def _axis_scale(axis):
        # A target size of zero uses a scale of 1.0.
        if core.symbolic_equal_dim(shape[axis], 0):
            return 1.0
        return (core.dimension_as_value(shape[axis]) /
                core.dimension_as_value(image.shape[axis]))

    scale = [_axis_scale(axis) for axis in spatial_dims]
    translation = [0.] * len(spatial_dims)
    return _scale_and_translate(image, shape, spatial_dims, scale, translation,
                                kernel, antialias, precision)
def _multinomial(key, p, n, n_max, shape=()):
    """Draws multinomial counts by sampling `n_max` categories and tallying them.

    :param key: random key forwarded to `categorical`.
    :param p: category probabilities; the last axis indexes categories.
    :param n: number of trials, broadcast against the leading axes of `p`.
    :param n_max: number of categorical draws made per batch element; draws
        beyond `n` are masked out.
    :param shape: batch shape of the result; when empty, `p.shape[:-1]` is used.
    :return: array of counts with shape `shape + p.shape[-1:]`.
    """
    if jnp.shape(n) != jnp.shape(p)[:-1]:
        broadcast_shape = lax.broadcast_shapes(jnp.shape(n), jnp.shape(p)[:-1])
        n = jnp.broadcast_to(n, broadcast_shape)
        p = jnp.broadcast_to(p, broadcast_shape + jnp.shape(p)[-1:])
    shape = shape or p.shape[:-1]
    if n_max == 0:
        # No draws at all: every count is zero.
        return jnp.zeros(shape + p.shape[-1:], dtype=jnp.result_type(int))
    # get indices from categorical distribution then gather the result
    indices = categorical(key, p, (n_max,) + shape)
    # mask out values when counts is heterogeneous
    if jnp.ndim(n) > 0:
        mask = promote_shapes(
            jnp.arange(n_max) < jnp.expand_dims(n, -1), shape=shape + (n_max,)
        )[0]
        mask = jnp.moveaxis(mask, -1, 0).astype(indices.dtype)
        # Masked draws are multiplied by 0 below and therefore land on
        # category 0. `excess` holds those `n_max - n` surplus counts in the
        # first column so they can be subtracted at the end.
        excess = jnp.concatenate(
            [
                jnp.expand_dims(n_max - n, -1),
                jnp.zeros(jnp.shape(n) + (p.shape[-1] - 1,)),
            ],
            -1,
        )
    else:
        mask = 1
        excess = 0
    # NB: we transpose to move batch shape to the front
    indices_2D = (jnp.reshape(indices * mask, (n_max, -1))).T
    # One row per flattened batch element; each drawn index contributes a 1.
    samples_2D = vmap(_scatter_add_one, (0, 0, 0))(
        jnp.zeros((indices_2D.shape[0], p.shape[-1]), dtype=indices.dtype),
        jnp.expand_dims(indices_2D, axis=-1),
        jnp.ones(indices_2D.shape, dtype=indices.dtype),
    )
    return jnp.reshape(samples_2D, shape + p.shape[-1:]) - excess
def test_log_prob_gradient(jax_dist, sp_dist, params):
    """Checks autodiff gradients of the summed log-prob against central differences.

    The gradient with respect to all parameters is taken in a single
    `jax.grad` call. Each parameter whose result type is not an integer type
    is then perturbed by +/- eps and compared with its gradient entry.
    """
    if jax_dist is dist.LKJCholesky:
        pytest.skip('we have separated tests for LKJCholesky distribution')

    rng = random.PRNGKey(0)

    def fn(args, value):
        return np.sum(jax_dist(*args).log_prob(value))

    value = jax_dist(*params).sample(rng)
    actual_grad = jax.grad(fn)(params, value)
    assert len(actual_grad) == len(params)

    eps = 1e-3
    integer_types = (np.int32, np.int64)
    for i, param in enumerate(params):
        if np.result_type(param) in integer_types:
            continue

        # Shift only the i-th parameter; all others stay as they are.
        args_lhs = [p - eps if j == i else p for j, p in enumerate(params)]
        args_rhs = [p + eps if j == i else p for j, p in enumerate(params)]
        fn_lhs = fn(args_lhs, value)
        fn_rhs = fn(args_rhs, value)
        # finite diff approximation
        expected_grad = (fn_rhs - fn_lhs) / (2. * eps)

        grad_i = actual_grad[i]
        assert np.shape(grad_i) == np.shape(param)
        assert_allclose(np.sum(grad_i), expected_grad, rtol=0.01, atol=1e-3)
def init_fn(z_info, rng_key, step_size=1.0, inverse_mass_matrix=None,
            mass_matrix_size=None):
    """
    Builds the initial state of the adaptation scheme.

    :param IntegratorState z_info: The initial integrator state.
    :param jax.random.PRNGKey rng_key: Random key to be used as the source of
        randomness.
    :param float step_size: Initial step size.
    :param inverse_mass_matrix: Inverse of the initial mass matrix; it is handed
        to `_initialize_mass_matrix` together with `z_info[0]`.
    :param int mass_matrix_size: Size of the mass matrix (not referenced in
        this function body).
    :return: initial state of the adapt scheme.
    """
    rng_key, rng_key_ss = random.split(rng_key)

    mass_matrices = _initialize_mass_matrix(
        z_info[0], inverse_mass_matrix, dense_mass)
    inverse_mass_matrix, mass_matrix_sqrt, mass_matrix_sqrt_inv = mass_matrices

    if adapt_step_size:
        step_size = find_reasonable_step_size(
            step_size, inverse_mass_matrix, z_info, rng_key_ss)
    ss_state = ss_init(jnp.log(10 * step_size))

    # A dict of matrices yields a dict of shapes; otherwise use the last axis.
    size = (
        {name: matrix.shape for name, matrix in inverse_mass_matrix.items()}
        if isinstance(inverse_mass_matrix, dict)
        else inverse_mass_matrix.shape[-1]
    )
    mm_state = mm_init(size)

    window_idx = jnp.array(0, dtype=jnp.result_type(int))
    return HMCAdaptState(
        step_size,
        inverse_mass_matrix,
        mass_matrix_sqrt,
        mass_matrix_sqrt_inv,
        ss_state,
        mm_state,
        window_idx,
        rng_key,
    )
def __init__(self, name, shape):
    """Creates a tracked, parentless prior backed by a `UniformBase` of `shape`."""
    float_dtype = jnp.result_type(float)
    super().__init__(
        name,
        shape,
        parents=[],
        tracked=True,
        prior_base=UniformBase(shape, float_dtype),
    )
def init(self, rng_key, *args, **kwargs):
    """
    Builds the initial SVI state.

    :param jax.random.PRNGKey rng_key: random number generator seed.
    :param args: arguments to the model / guide (these can possibly vary during
        the course of fitting).
    :param kwargs: keyword arguments to the model / guide (these can possibly
        vary during the course of fitting).
    :return: the initial :data:`SVIState`
    """
    rng_key, model_seed, guide_seed = random.split(rng_key, 3)
    model_init = seed(self.model, model_seed)
    guide_init = seed(self.guide, guide_seed)

    guide_trace = trace(guide_init).get_trace(
        *args, **kwargs, **self.static_kwargs)
    replayed_model = replay(model_init, guide_trace)
    model_trace = trace(replayed_model).get_trace(
        *args, **kwargs, **self.static_kwargs)

    params = {}
    inv_transforms = {}
    # NB: params in model_trace will be overwritten by params in guide_trace
    for site_trace in (model_trace, guide_trace):
        for site in site_trace.values():
            if site['type'] != 'param':
                continue
            name = site['name']
            constraint = site['kwargs'].pop('constraint', constraints.real)
            transform = biject_to(constraint)
            inv_transforms[name] = transform
            params[name] = transform.inv(site['value'])

    self.constrain_fn = partial(transform_fn, inv_transforms)

    def _to_concrete_dtype(leaf):
        # we convert weak types like float to float32/float64
        # to avoid recompiling body_fn in svi.run
        return lax.convert_element_type(leaf, jnp.result_type(leaf))

    params = tree_map(_to_concrete_dtype, params)
    return SVIState(self.optim.init(params), rng_key)
def enumerate_support(self, expand=True):
    """Returns every one-hot vector of the support, stacked along a new leading axis.

    :param bool expand: if true, the result is broadcast over `batch_shape`;
        otherwise the batch dimensions are kept with size 1.
    :return: array of shape `(n,) + batch_shape + (n,)` when `expand` is true,
        or `(n,) + (1,) * len(batch_shape) + (n,)` otherwise, with
        `n = event_shape[-1]`.
    """
    num_categories = self.event_shape[-1]
    batch_shape = self.batch_shape

    one_hot = jnp.identity(num_categories, dtype=jnp.result_type(self.dtype))
    singleton_batch = (1,) * len(batch_shape)
    one_hot = one_hot.reshape(
        (num_categories,) + singleton_batch + (num_categories,))

    if not expand:
        return one_hot
    return jnp.broadcast_to(
        one_hot, (num_categories,) + batch_shape + (num_categories,))
def sample(self, key, sample_shape=()):
    """Draws samples by transforming uniform noise with the success probabilities.

    :param key: random key; must satisfy `is_prng_key`.
    :param tuple sample_shape: leading shape prepended to `batch_shape`.
    :return: `floor(log1p(-u) / log1p(-probs))` for uniform draws `u`.
    """
    assert is_prng_key(key)
    success_probs = self.probs
    draw_shape = sample_shape + self.batch_shape
    uniforms = random.uniform(key, draw_shape, jnp.result_type(success_probs))
    numerator = jnp.log1p(-uniforms)
    denominator = jnp.log1p(-success_probs)
    return jnp.floor(numerator / denominator)
def _kth_arnoldi_iteration(k, A, M, V, H):
    """
    Performs a single (the k'th) step of the Arnoldi process.

    The candidate vector M(A(V[..., k])) is orthogonalized against the existing
    Krylov vectors and normalized; the result is written to V[..., k + 1], and
    the overlaps together with the new vector's norm are written to H[k, :].

    The threshold passed to the final normalization is the machine epsilon of
    V's result type times the norm of the candidate before orthogonalization.
    `breakdown` is true when the norm returned by that normalization is exactly
    zero. NOTE(review): presumably `_safe_normalize` returns a zero vector and
    zero norm below the threshold, signalling an invariant subspace; confirm
    against its definition.

    Returns:
      A tuple `(V, H, breakdown)` with the updated Krylov basis, the updated
      matrix `H`, and the breakdown flag.
    """
    dtype = jnp.result_type(*tree_leaves(V))
    eps = jnp.finfo(dtype).eps

    v = tree_map(lambda x: x[..., k], V)  # Gets V[:, k]
    v = M(A(v))
    # Norm before orthogonalization; it sets the scale of the tolerance below.
    _, v_norm_0 = _safe_normalize(v)
    v, h = _iterative_classical_gram_schmidt(V, v, v_norm_0, max_iterations=2)

    tol = eps * v_norm_0
    unit_v, v_norm_1 = _safe_normalize(v, thresh=tol)
    V = tree_map(lambda X, y: X.at[..., k + 1].set(y), V, unit_v)

    h = h.at[k + 1].set(v_norm_1.astype(dtype))
    H = H.at[k, :].set(h)
    breakdown = v_norm_1 == 0.
    return V, H, breakdown
def _cg_solve(A, b, x0=None, *, maxiter, tol=1e-5, atol=0.0, M=_identity):
    """Runs the preconditioned conjugate-gradient iteration in a `lax.while_loop`.

    Args:
      A: callable applying the linear operator.
      b: right-hand side.
      x0: starting point. NOTE(review): it is passed to `A` as-is, so despite
        the `None` default callers presumably always supply a value; confirm.
      maxiter: upper bound on the number of iterations.
      tol: relative tolerance, applied to `_vdot_real_tree(b, b)`.
      atol: absolute tolerance.
      M: callable applying the preconditioner.

    Returns:
      The final iterate.
    """

    # tolerance handling uses the "non-legacy" behavior of scipy.sparse.linalg.cg
    bs = _vdot_real_tree(b, b)
    # Squared threshold: max(tol**2 * bs, atol**2).
    atol2 = jnp.maximum(jnp.square(tol) * bs, jnp.square(atol))

    # https://en.wikipedia.org/wiki/Conjugate_gradient_method#The_preconditioned_conjugate_gradient_method

    def cond_fun(value):
        _, r, gamma, _, k = value
        # With the identity preconditioner `gamma` is reused as the residual
        # measure; otherwise it is recomputed from `r`.
        rs = gamma.real if M is _identity else _vdot_real_tree(r, r)
        return (rs > atol2) & (k < maxiter)

    def body_fun(value):
        x, r, gamma, p, k = value
        Ap = A(p)
        # `dtype` is bound further down in the enclosing function, before the
        # loop runs.
        alpha = gamma / _vdot_real_tree(p, Ap).astype(dtype)
        x_ = _add(x, _mul(alpha, p))
        r_ = _sub(r, _mul(alpha, Ap))
        z_ = M(r_)
        gamma_ = _vdot_real_tree(r_, z_).astype(dtype)
        beta_ = gamma_ / gamma
        p_ = _add(z_, _mul(beta_, p))
        return x_, r_, gamma_, p_, k + 1

    r0 = _sub(b, A(x0))
    p0 = z0 = M(r0)
    dtype = jnp.result_type(*tree_leaves(p0))
    gamma0 = _vdot_real_tree(r0, z0).astype(dtype)
    initial_value = (x0, r0, gamma0, p0, 0)

    x_final, *_ = lax.while_loop(cond_fun, body_fun, initial_value)

    return x_final
def sample(self, key, sample_shape=()):
    """Draws samples by transforming uniform noise with the logits.

    :param key: random key; must satisfy `is_prng_key`.
    :param tuple sample_shape: leading shape prepended to `batch_shape`.
    :return: `floor(log1p(-u) / -softplus(logits))` for uniform draws `u`.
    """
    assert is_prng_key(key)
    logits = self.logits
    draw_shape = sample_shape + self.batch_shape
    uniforms = random.uniform(key, draw_shape, jnp.result_type(logits))
    numerator = jnp.log1p(-uniforms)
    denominator = -softplus(logits)
    return jnp.floor(numerator / denominator)
def sample(self, key, sample_shape=()):
    """Draws samples through the base distribution's inverse CDF.

    Uniform draws, bounded below by the smallest positive normal float, are
    scaled by `self._cdf_at_high` before being passed to `base_dist.icdf`.

    :param key: random key; must satisfy `is_prng_key`.
    :param tuple sample_shape: leading shape prepended to `batch_shape`.
    """
    assert is_prng_key(key)
    smallest_positive = jnp.finfo(jnp.result_type(float)).tiny
    draw_shape = sample_shape + self.batch_shape
    u = random.uniform(key, shape=draw_shape, minval=smallest_positive)
    return self.base_dist.icdf(u * self._cdf_at_high)
def _jacobian_cplx(forward_fn: Callable, params: PyTree, samples: Array,
                   _build_fn: Callable) -> PyTree:
    """Computes one Jacobian entry by back-propagating 1 and -1j.

    Args:
        forward_fn: the log wavefunction ln Ψ
        params: a pytree of parameters p
        samples: a single sample σ (vector)
        _build_fn: callable that combines the two pulled-back gradients into
            the returned value

    Returns:
        `_build_fn(gr, gi)`, where `gr` and `gi` are the parameter gradients
        obtained from the cotangents 1 and -1j respectively.
    """
    y, vjp_fun = jax.vjp(single_sample(forward_fn), params, samples)
    out_dtype = jnp.result_type(y)

    real_cotangent = np.array(1.0, dtype=out_dtype)
    gr, _ = vjp_fun(real_cotangent)

    imag_cotangent = np.array(-1.0j, dtype=out_dtype)
    gi, _ = vjp_fun(imag_cotangent)

    return _build_fn(gr, gi)
def _inverse(self, y):
    """Undoes the permutation applied along the last axis of `y`.

    The inverse permutation is built by scattering `arange(size)` into the
    positions given by `self.permutation`, so that
    `permutation_inv[self.permutation[i]] == i`.

    The scatter uses the `.at[...].set(...)` indexed-update syntax;
    `jax.ops.index_update` has been removed from JAX.

    :param y: array whose last axis was permuted with `self.permutation`.
    :return: `y` with the last axis reordered by the inverse permutation.
    """
    size = self.permutation.size
    permutation_inv = (
        jnp.zeros(size, dtype=jnp.result_type(int))
        .at[self.permutation]
        .set(jnp.arange(size))
    )
    return y[..., permutation_inv]
def _inverse(self, y):
    """Reorders the last axis of `y` with the inverse of `self.permutation`.

    :param y: array whose last axis was permuted with `self.permutation`.
    :return: `y[..., inverse]`, where `inverse[self.permutation[i]] == i`.
    """
    forward = self.permutation
    length = forward.size
    # Start from zeros and scatter each position index into its source slot.
    inverse = jnp.zeros(length, dtype=jnp.result_type(int))
    inverse = inverse.at[forward].set(jnp.arange(length))
    return y[..., inverse]
def _compute_stats(x, axes):
    """Returns the mean and a non-negative variance of `x` over `axes`.

    The input is first promoted to at least float32, so half-precision inputs
    are computed in float32 while double and complex inputs keep their type.
    """
    compute_dtype = jnp.promote_types(jnp.float32, jnp.result_type(x))
    x = jnp.asarray(x, compute_dtype)

    mean = jnp.mean(x, axes)
    second_moment = jnp.mean(jnp.square(x), axes)
    # Round-off can push E[x^2] - E[x]^2 slightly below zero, hence the clamp.
    var = jnp.maximum(0., second_moment - jnp.square(mean))
    return mean, var
def value_and_grad_f(*args, **kwargs):
    """Evaluates `fun` and pulls a cotangent back through it.

    `fun`, `argnums` and `initial_grad` come from the enclosing scope. When
    `initial_grad` is `None`, the cotangent is a scalar one with the result
    type of the output.

    Returns:
      A pair `(ans, g)`: the value of `fun` and the pulled-back gradient. When
      `argnums` is an int, `g` is the single gradient rather than a tuple.
    """
    f = lu.wrap_init(fun, kwargs)
    f_partial, dyn_args = argnums_partial(f, argnums, args)
    ans, vjp_py = vjp(f_partial, *dyn_args)

    if initial_grad is None:
        cotangent = jnp.ones((), jnp.result_type(ans))
    else:
        cotangent = initial_grad

    g = vjp_py(cotangent)
    if isinstance(argnums, int):
        g = g[0]
    return (ans, g)
def _gmres_qr(A, b, x0, unit_residual, residual_norm, inner_tol, restart, M): """ Implements a single restart of GMRES. The restart-dimensional Krylov subspace K(A, x0) = span(A(x0), A@x0, A@A@x0, ..., A^restart @ x0) is built, and the projection of the true solution into this subspace is returned. This implementation builds the QR factorization during the Arnoldi process. """ # https://www-users.cs.umn.edu/~saad/Calais/PREC.pdf # residual = _sub(b, A(x0)) # unit_residual, beta = _safe_normalize(residual) V = tree_map( lambda x: jnp.pad(x[..., None], ((0, 0), ) * x.ndim + ((0, restart), )), unit_residual, ) dtype = jnp.result_type(*tree_leaves(b)) R = jnp.eye(restart, restart + 1, dtype=dtype) # eye to avoid constructing # a singular matrix in case # of early termination. b_norm = _norm_tree(b) givens = jnp.zeros((restart, 2), dtype=dtype) beta_vec = jnp.zeros((restart + 1), dtype=dtype) beta_vec = beta_vec.at[0].set(residual_norm) def loop_cond(carry): k, err, _, _, _, _ = carry return jnp.logical_and(k < restart, err > inner_tol) def arnoldi_qr_step(carry): k, _, V, R, beta_vec, givens = carry V, H, _ = _kth_arnoldi_iteration(k, A, M, V, R, inner_tol) R_row, givens = _apply_givens_rotations(H[k, :], givens, k) R = R.at[k, :].set(R_row[:]) cs, sn = givens[k, :] * beta_vec[k] beta_vec = beta_vec.at[k].set(cs) beta_vec = beta_vec.at[k + 1].set(sn) err = jnp.abs(sn) / b_norm return k + 1, err, V, R, beta_vec, givens carry = (0, residual_norm, V, R, beta_vec, givens) carry = lax.while_loop(loop_cond, arnoldi_qr_step, carry) k, residual_norm, V, R, beta_vec, _ = carry del k # Until we figure out how to pass this to the user. y = jsp.linalg.solve_triangular(R[:, :-1].T, beta_vec[:-1]) Vy = tree_map(lambda X: _dot(X[..., :-1], y), V) dx = M(Vy) x = _add(x0, dx) residual = _sub(b, A(x)) unit_residual, residual_norm = _safe_normalize(residual) return x, unit_residual, residual_norm
def sample(self, key, sample_shape=()):
    """Draws samples through the base distribution's inverse CDF.

    Uniform draws interpolate between `_tail_prob_at_low` and
    `_tail_prob_at_high`. The icdf result is then either kept as-is or
    reflected about `loc`, depending on whether `loc >= self.low`.

    :param key: random key; must satisfy `is_prng_key`.
    :param tuple sample_shape: leading shape prepended to `batch_shape`.
    """
    assert is_prng_key(key)
    smallest_positive = jnp.finfo(jnp.result_type(float)).tiny
    draw_shape = sample_shape + self.batch_shape
    u = random.uniform(key, shape=draw_shape, minval=smallest_positive)

    loc = self.base_dist.loc
    sign = jnp.where(loc >= self.low, 1.0, -1.0)
    tail_prob = (1 - u) * self._tail_prob_at_low + u * self._tail_prob_at_high
    quantile = self.base_dist.icdf(tail_prob)
    return (1 - sign) * loc + sign * quantile
def init_fn(prox_center=0.0):
    """
    Creates the initial state of the scheme.

    :param float prox_center: A parameter introduced in reference [1] which
        pulls the primal sequence towards it. Defaults to 0.
    :return: initial state for the scheme, as the tuple
        `(x_t, x_avg, g_avg, t, prox_center)`.
    """
    step_count = jnp.array(0, dtype=jnp.result_type(int))
    # Current primal value, running primal average and running dual average
    # all start as scalar zeros.
    primal = jnp.zeros(())
    primal_average = jnp.zeros(())
    dual_average = jnp.zeros(())
    return primal, primal_average, dual_average, step_count, prox_center
def _sample_n(self, key: PRNGKey, n: int) -> Array:
    """See `Distribution._sample_n`."""
    sample_dtype = jnp.result_type(self._loc, self._scale)
    # The lower bound is the smallest positive normal float, not zero.
    u = jax.random.uniform(
        key,
        shape=(n,) + self.batch_shape,
        dtype=sample_dtype,
        minval=jnp.finfo(sample_dtype).tiny,
        maxval=1.)
    # Difference of log(u) and log(1 - u), then scaled and shifted.
    standard_draw = jnp.log(u) - jnp.log1p(-u)
    return self._scale * standard_draw + self._loc
def build_tree(verlet_update, kinetic_fn, verlet_state, inverse_mass_matrix, step_size, rng_key,
               max_delta_energy=1000., max_tree_depth=10):
    """
    Builds a binary tree from the `verlet_state`. This is used in NUTS sampler.

    The tree is doubled repeatedly in a `while_loop`, each time in a randomly
    chosen direction, until it reaches `max_tree_depth` or is flagged as
    turning or diverging.

    **References:**

    1. *The No-U-Turn Sampler: Adaptively Setting Path Lengths in Hamiltonian Monte Carlo*,
       Matthew D. Hoffman, Andrew Gelman
    2. *A Conceptual Introduction to Hamiltonian Monte Carlo*,
       Michael Betancourt

    :param verlet_update: A callable to get a new integrator state given a current
        integrator state.
    :param kinetic_fn: A callable to compute kinetic energy.
    :param verlet_state: Initial integrator state.
    :param inverse_mass_matrix: Inverse of the mass matrix.
    :param float step_size: Step size for the current trajectory.
    :param jax.random.PRNGKey rng_key: random key to be used as the source of
        randomness.
    :param float max_delta_energy: A threshold to decide if the new state diverges
        (based on the energy difference) too much from the initial integrator state.
    :param int max_tree_depth: Upper bound on the tree depth; it also sets the
        number of rows in the checkpoint buffers.
    :return: information of the tree.
    :rtype: :data:`TreeInfo`
    """
    z, r, potential_energy, z_grad = verlet_state
    energy_current = potential_energy + kinetic_fn(inverse_mass_matrix, r)
    # Checkpoint buffers: one row per depth level, one column per element of
    # the flattened `r`.
    latent_size = jnp.size(ravel_pytree(r)[0])
    r_ckpts = jnp.zeros((max_tree_depth, latent_size))
    r_sum_ckpts = jnp.zeros((max_tree_depth, latent_size))

    # Depth-0 tree built only from the initial state.
    tree = TreeInfo(z, r, z_grad, z, r, z_grad, z, potential_energy, z_grad, energy_current,
                    depth=0, weight=jnp.zeros(()), r_sum=r, turning=jnp.array(False),
                    diverging=jnp.array(False),
                    sum_accept_probs=jnp.zeros(()),
                    num_proposals=jnp.array(0, dtype=jnp.result_type(int)))

    def _cond_fn(state):
        tree, _ = state
        return (tree.depth < max_tree_depth) & ~tree.turning & ~tree.diverging

    def _body_fn(state):
        tree, key = state
        # Fresh sub-keys: one for the direction, one for the doubling step.
        key, direction_key, doubling_key = random.split(key, 3)
        going_right = random.bernoulli(direction_key)
        tree = _double_tree(tree, verlet_update, kinetic_fn, inverse_mass_matrix, step_size,
                            going_right, doubling_key, energy_current, max_delta_energy,
                            r_ckpts, r_sum_ckpts)
        return tree, key

    state = (tree, rng_key)
    tree, _ = while_loop(_cond_fn, _body_fn, state)
    return tree
def cartesian_product(*arrays):
    """
    IN: any number of one-dimensional arrays
    OUT: their cartesian product, one combination per row and one column per
         input array
    """
    num_arrays = len(arrays)
    common_dtype = np.result_type(*arrays)

    grid_shape = [len(a) for a in arrays]
    grid_shape.append(num_arrays)
    grid = np.empty(grid_shape, dtype=common_dtype)

    # Each open-mesh component broadcasts over the grid and fills one column
    # slot (the functional equivalent of `grid[..., axis] = coords`).
    for axis, coords in enumerate(np.ix_(*arrays)):
        grid = index_update(grid, index[..., axis], coords)
    return grid.reshape(-1, num_arrays)
def init(self, rng_key, *args, **kwargs):
    """
    Gets the initial SVI state.

    The guide is traced first and the model is traced while replaying the
    guide trace. Parameter sites are collected in unconstrained form, mutable
    sites are collected as-is, and a warning is emitted for unobserved
    discrete sample sites when the loss cannot infer discrete variables.
    As a side effect, `self.constrain_fn` is set.

    :param jax.random.PRNGKey rng_key: random number generator seed.
    :param args: arguments to the model / guide (these can possibly vary during
        the course of fitting).
    :param kwargs: keyword arguments to the model / guide (these can possibly vary
        during the course of fitting).
    :return: the initial :data:`SVIState`
    """
    rng_key, model_seed, guide_seed = random.split(rng_key, 3)
    model_init = seed(self.model, model_seed)
    guide_init = seed(self.guide, guide_seed)
    guide_trace = trace(guide_init).get_trace(*args, **kwargs, **self.static_kwargs)
    model_trace = trace(replay(model_init, guide_trace)).get_trace(
        *args, **kwargs, **self.static_kwargs
    )
    params = {}
    inv_transforms = {}
    mutable_state = {}
    # NB: params in model_trace will be overwritten by params in guide_trace
    for site in list(model_trace.values()) + list(guide_trace.values()):
        if site["type"] == "param":
            # `pop` removes the constraint from the site's kwargs; sites
            # without one default to `constraints.real`.
            constraint = site["kwargs"].pop("constraint", constraints.real)
            with helpful_support_errors(site):
                transform = biject_to(constraint)
            inv_transforms[site["name"]] = transform
            # Store the parameter in unconstrained space.
            params[site["name"]] = transform.inv(site["value"])
        elif site["type"] == "mutable":
            mutable_state[site["name"]] = site["value"]
        elif (
            site["type"] == "sample"
            and (not site["is_observed"])
            and site["fn"].support.is_discrete
            and not self.loss.can_infer_discrete
        ):
            s_name = type(self.loss).__name__
            warnings.warn(
                f"Currently, SVI with {s_name} loss does not support models with discrete latent variables"
            )

    # An empty mutable state is represented as None.
    if not mutable_state:
        mutable_state = None
    self.constrain_fn = partial(transform_fn, inv_transforms)
    # we convert weak types like float to float32/float64
    # to avoid recompiling body_fn in svi.run
    params, mutable_state = tree_map(
        lambda x: lax.convert_element_type(x, jnp.result_type(x)),
        (params, mutable_state),
    )
    return SVIState(self.optim.init(params), mutable_state, rng_key)