def testDerivativeIsBoundedWhenAlphaIsBelow2(self):
  # Assert that |d_x| < |x| / scale**2 when alpha <= 2.
  _, _, x, alpha, scale, d_x, _, _ = self._precompute_lossfun_inputs()
  mask = jnp.isfinite(alpha) & (alpha <= 2)
  grad = jnp.abs(d_x[mask])
  bound = ((jnp.abs(x[mask]) +
            (300. * jnp.finfo(jnp.float32).eps)) / scale[mask]**2)
  self.assertTrue(jnp.all(grad <= bound))
def binary_crossentropy_loss(params, predict, data):
  inputs, targets = data
  probs = predict(params, inputs)
  eps = jnp.finfo(probs.dtype).eps
  probs = jnp.clip(probs, eps, 1 - eps)
  loss = -(jsp.special.xlogy(targets, probs) +
           jsp.special.xlogy(1 - targets, 1 - probs)).mean()
  return loss
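# A minimal usage sketch (not from the original source): pairs the loss above
# with a tiny logistic-regression `predict`. It assumes `jsp` is `jax.scipy`
# and that `binary_crossentropy_loss` is in scope; names below are illustrative.
import jax
import jax.numpy as jnp
import jax.scipy as jsp

def predict(params, inputs):
  w, b = params
  return jax.nn.sigmoid(inputs @ w + b)

key = jax.random.PRNGKey(0)
inputs = jax.random.normal(key, (32, 3))
targets = (inputs @ jnp.array([1., -2., 0.5]) > 0).astype(jnp.float32)
params = (jnp.zeros(3), 0.0)
loss, grads = jax.value_and_grad(binary_crossentropy_loss)(
    params, predict, (inputs, targets))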
def log_abs_det_jacobian(self, x, y, intermediates=None):
  # Ref: https://mc-stan.org/docs/2_19/reference-manual/simplex-transform-section.html
  # |det|(J) = Product(y * (1 - z))
  x = x - jnp.log(x.shape[-1] - jnp.arange(x.shape[-1]))
  z = jnp.clip(expit(x), a_min=jnp.finfo(x.dtype).tiny)
  # XXX we use the identity 1 - z = z * exp(-x) so we need not worry about
  # the case z ~ 1
  return jnp.sum(jnp.log(y[..., :-1] * z) - x, axis=-1)
def testDerivativeIsMonotonicWrtX(self):
  # Check that the loss increases monotonically with |x|.
  _, _, x, alpha, _, d_x, _, _ = self._precompute_lossfun_inputs()
  # This is just to suppress a warning below.
  d_x = jnp.where(jnp.isfinite(d_x), d_x, jnp.zeros_like(d_x))
  mask = jnp.isfinite(alpha) & (
      jnp.abs(d_x) > (300. * jnp.finfo(jnp.float32).eps))
  chex.assert_trees_all_close(jnp.sign(d_x[mask]), jnp.sign(x[mask]))
def testLossIsBoundedWhenAlphaIsNegative(self):
  # Assert that loss < (alpha - 2) / alpha when alpha < 0.
  _, loss, _, alpha, _, _, _, _ = self._precompute_lossfun_inputs()
  mask = alpha < 0.
  min_val = jnp.finfo(jnp.float32).min
  alpha_clipped = jnp.maximum(min_val, alpha[mask])
  self.assertTrue(
      jnp.all(loss[mask] <= ((alpha_clipped - 2.) / alpha_clipped)))
def find_reasonable_step_size(potential_fn, kinetic_fn, momentum_generator,
                              inverse_mass_matrix, position, rng, init_step_size):
  """
  Finds a reasonable step size by tuning `init_step_size`. This function is
  used to avoid working with a step size that is too large or too small in HMC.

  **References:**

  1. *The No-U-Turn Sampler: Adaptively Setting Path Lengths in Hamiltonian Monte Carlo*,
     Matthew D. Hoffman, Andrew Gelman

  :param potential_fn: A callable to compute potential energy.
  :param kinetic_fn: A callable to compute kinetic energy.
  :param momentum_generator: A generator to get a random momentum variable.
  :param inverse_mass_matrix: Inverse of mass matrix.
  :param position: Current position of the particle.
  :param jax.random.PRNGKey rng: Random key to be used as the source of randomness.
  :param float init_step_size: Initial step size to be tuned.
  :return: a reasonable value for step size.
  :rtype: float
  """
  # We are going to find a step_size which makes accept_prob (the Metropolis
  # correction) near the target_accept_prob. If accept_prob := exp(-delta_energy)
  # is small, then we have to decrease step_size; otherwise, we increase step_size.
  target_accept_prob = np.log(0.8)
  _, vv_update = velocity_verlet(potential_fn, kinetic_fn)
  z = position
  potential_energy, z_grad = value_and_grad(potential_fn)(z)
  tiny = np.finfo(get_dtype(init_step_size)).tiny

  def _body_fn(state):
    step_size, _, direction, rng = state
    rng, rng_momentum = random.split(rng)
    # Scale step_size: increase 2x or decrease 2x depending on direction;
    # direction=1 means keep increasing step_size, otherwise keep decreasing it.
    # Note that the direction is -1 if delta_energy is `NaN`, which may be the
    # case for a diverging trajectory (e.g. in the case of evaluating log prob
    # of a value simulated using a large step size for a constrained sample site).
    step_size = (2.0 ** direction) * step_size
    r = momentum_generator(inverse_mass_matrix, rng_momentum)
    _, r_new, potential_energy_new, _ = vv_update(
        step_size, inverse_mass_matrix, (z, r, potential_energy, z_grad))
    energy_current = kinetic_fn(inverse_mass_matrix, r) + potential_energy
    energy_new = kinetic_fn(inverse_mass_matrix, r_new) + potential_energy_new
    delta_energy = energy_new - energy_current
    direction_new = np.where(target_accept_prob < -delta_energy, 1, -1)
    return step_size, direction, direction_new, rng

  def _cond_fn(state):
    step_size, last_direction, direction, _ = state
    # Run only while step_size is not too small or we are not decreasing step_size.
    not_small_step_size_cond = (step_size > tiny) | (direction >= 0)
    return not_small_step_size_cond & (
        (last_direction == 0) | (direction == last_direction))

  step_size, _, _, _ = while_loop(_cond_fn, _body_fn, (init_step_size, 0, 0, rng))
  return step_size
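# A self-contained sketch of the same doubling/halving heuristic for a 1D
# standard normal, with an explicit leapfrog step standing in for numpyro's
# velocity_verlet. All names below are illustrative assumptions, not numpyro API.
import jax
import jax.numpy as jnp

def find_step_size_1d(key, init_step_size=1.0, target_log_accept=jnp.log(0.8)):
  potential = lambda z: 0.5 * z ** 2  # negative log density of N(0, 1)
  z = 0.5                             # fixed current position
  step_size, direction = init_step_size, 0
  for _ in range(100):
    key, subkey = jax.random.split(key)
    r = jax.random.normal(subkey)
    # One leapfrog step from (z, r).
    r_half = r - 0.5 * step_size * jax.grad(potential)(z)
    z_new = z + step_size * r_half
    r_new = r_half - 0.5 * step_size * jax.grad(potential)(z_new)
    delta_energy = (potential(z_new) + 0.5 * r_new ** 2) - (potential(z) + 0.5 * r ** 2)
    new_direction = 1 if -delta_energy > target_log_accept else -1
    if direction != 0 and new_direction != direction:
      break  # the acceptance probability just crossed the target
    direction = new_direction
    step_size = step_size * 2.0 ** direction
  return step_size

print(find_step_size_1d(jax.random.PRNGKey(0)))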
def sample(self, key, sample_shape=()):
  assert is_prng_key(key)
  dtype = jnp.result_type(float)
  finfo = jnp.finfo(dtype)
  minval = finfo.tiny
  u = random.uniform(key, shape=sample_shape + self.batch_shape, minval=minval)
  return self.base_dist.icdf(u * self._cdf_at_high)
def _projector_subspace(P, H, rank, maxiter=2):
  """Decomposes the `n x n` rank-`rank` Hermitian projector `P` into an
  `n x rank` isometry `Vm` such that `P = Vm @ Vm.conj().T` and an
  `n x (n - rank)` isometry `Vp` such that `(I - P) = Vp @ Vp.conj().T`.

  The subspaces are computed using the naive QR eigendecomposition algorithm,
  which converges very quickly due to the sharp separation between the
  relevant eigenvalues of the projector.

  Args:
    P: A rank-`rank` Hermitian projector into the space of `H`'s first `rank`
      eigenpairs.
    H: The aforementioned Hermitian matrix, which is used to track convergence.
    rank: Rank of `P`.
    maxiter: Maximum number of iterations.
  Returns:
    Vm, Vp: Isometries into the eigenspaces described in the docstring.
  """
  # Choose an initial guess: the `rank` largest-norm columns of P.
  column_norms = jnp.linalg.norm(P, axis=1)
  sort_idxs = jnp.argsort(column_norms)
  X = P[:, sort_idxs]
  X = X[:, :rank]

  H_norm = jnp.linalg.norm(H)
  thresh = 10 * jnp.finfo(X.dtype).eps * H_norm

  # First iteration skips the matmul.
  def body_f_after_matmul(X):
    Q, _ = jnp.linalg.qr(X, mode="complete")
    V1 = Q[:, :rank]
    V2 = Q[:, rank:]
    # TODO: might be able to get away with lower precision here
    error_matrix = jnp.dot(V2.conj().T, H, precision=lax.Precision.HIGHEST)
    error_matrix = jnp.dot(error_matrix, V1, precision=lax.Precision.HIGHEST)
    error = jnp.linalg.norm(error_matrix) / H_norm
    return V1, V2, error

  def cond_f(args):
    _, _, j, error = args
    still_counting = j < maxiter
    unconverged = error > thresh
    return jnp.logical_and(still_counting, unconverged)[0]

  def body_f(args):
    V1, _, j, _ = args
    X = jnp.dot(P, V1, precision=lax.Precision.HIGHEST)
    V1, V2, error = body_f_after_matmul(X)
    return V1, V2, j + 1, error

  V1, V2, error = body_f_after_matmul(X)
  one = jnp.ones(1, dtype=jnp.int32)
  V1, V2, _, error = lax.while_loop(cond_f, body_f, (V1, V2, one, error))
  return V1, V2
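# A hedged usage sketch (illustrative only; assumes `_projector_subspace` and its
# jnp/lax imports are in scope): build a rank-2 projector onto the two lowest
# eigenpairs of a small Hermitian matrix and recover the two subspaces.
import jax
import jax.numpy as jnp

Q, _ = jnp.linalg.qr(jax.random.normal(jax.random.PRNGKey(0), (4, 4)))
H = (Q * jnp.array([1., 2., 10., 11.])) @ Q.T   # Hermitian, clear spectral gap
P = Q[:, :2] @ Q[:, :2].T                        # projector onto the two lowest eigenpairs
V1, V2 = _projector_subspace(P, H, rank=2)
# V1 should span range(P) and V2 its orthogonal complement, so both norms are ~0.
print(float(jnp.linalg.norm(P @ V1 - V1)), float(jnp.linalg.norm(P @ V2)))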
def testNumpyAndXLAAgreeOnFloatEndianness(self, dtype):
  if not FLAGS.jax_enable_x64 and jnp.issubdtype(dtype, np.float64):
    raise SkipTest("can't test float64 agreement")

  bits_dtype = np.uint32 if jnp.finfo(dtype).bits == 32 else np.uint64
  numpy_bits = np.array(1., dtype).view(bits_dtype)
  xla_bits = api.jit(
      lambda: lax.bitcast_convert_type(np.array(1., dtype), bits_dtype))()
  self.assertEqual(numpy_bits, xla_bits)
def testPolar(self, n_zero_sv, degeneracy, geometric_spectrum, max_sv, shape,
              method, side, nonzero_condition_number, dtype, seed):
  """Tests jax.scipy.linalg.polar."""
  if jtu.device_under_test() != "cpu":
    if jnp.dtype(dtype).name in ("bfloat16", "float16"):
      raise unittest.SkipTest("Skip half precision off CPU.")

  m, n = shape
  if (method == "qdwh" and
      ((side == "left" and m >= n) or (side == "right" and m < n))):
    raise unittest.SkipTest("method=qdwh does not support these sizes")

  matrix, _ = _initialize_polar_test(self.rng(), shape, n_zero_sv, degeneracy,
                                     geometric_spectrum, max_sv,
                                     nonzero_condition_number, dtype)
  if jnp.dtype(dtype).name in ("bfloat16", "float16"):
    self.assertRaises(NotImplementedError, jsp.linalg.polar, matrix,
                      method=method, side=side)
    return

  unitary, posdef = jsp.linalg.polar(matrix, method=method, side=side)
  if shape[0] >= shape[1]:
    should_be_eye = np.matmul(unitary.conj().T, unitary)
  else:
    should_be_eye = np.matmul(unitary, unitary.conj().T)
  tol = 500 * float(jnp.finfo(matrix.dtype).eps)
  eye_mat = np.eye(should_be_eye.shape[0], dtype=should_be_eye.dtype)
  with self.subTest('Test unitarity.'):
    self.assertAllClose(eye_mat, should_be_eye, atol=tol * min(shape))

  with self.subTest('Test Hermiticity.'):
    self.assertAllClose(posdef, posdef.conj().T,
                        atol=tol * jnp.linalg.norm(posdef))

  ev, _ = np.linalg.eigh(posdef)
  ev = ev[np.abs(ev) > tol * np.linalg.norm(posdef)]
  negative_ev = jnp.sum(ev < 0.)
  with self.subTest('Test positive definiteness.'):
    self.assertEqual(negative_ev, 0)

  if side == "right":
    recon = jnp.matmul(unitary, posdef, precision=lax.Precision.HIGHEST)
  elif side == "left":
    recon = jnp.matmul(posdef, unitary, precision=lax.Precision.HIGHEST)
  with self.subTest('Test reconstruction.'):
    self.assertAllClose(matrix, recon, atol=tol * jnp.linalg.norm(matrix))
def lossfun(x, alpha, scale):
  r"""Implements the general form of the loss.

  This implements the rho(x, \alpha, c) function described in "A General and
  Adaptive Robust Loss Function", Jonathan T. Barron,
  https://arxiv.org/abs/1701.03077.

  Args:
    x: The residual for which the loss is being computed. x can have any shape,
      and alpha and scale will be broadcasted to match x's shape if necessary.
    alpha: The shape parameter of the loss (\alpha in the paper), where more
      negative values produce a loss with more robust behavior (outliers "cost"
      less), and more positive values produce a loss with less robust behavior
      (outliers are penalized more heavily). Alpha can be any value in
      [-infinity, infinity], but the gradient of the loss with respect to alpha
      is 0 at -infinity, infinity, 0, and 2. Varying alpha allows for smooth
      interpolation between several discrete robust losses:
        alpha=-Infinity: Welsch/Leclerc loss.
        alpha=-2: Geman-McClure loss.
        alpha=0: Cauchy/Lorentzian loss.
        alpha=1: Charbonnier/pseudo-Huber loss.
        alpha=2: L2 loss.
    scale: The scale parameter of the loss. When |x| < scale, the loss is an
      L2-like quadratic bowl, and when |x| > scale the loss function takes on a
      different shape according to alpha.

  Returns:
    The losses for each element of x, in the same shape as x.
  """
  eps = jnp.finfo(jnp.float32).eps

  # `scale` must be > 0.
  scale = jnp.maximum(eps, scale)

  # The loss when alpha == 2. This will get reused repeatedly.
  loss_two = 0.5 * (x / scale)**2

  # "Safe" versions of log1p and expm1 that will not NaN-out.
  log1p_safe = lambda x: jnp.log1p(jnp.minimum(x, 3e37))
  expm1_safe = lambda x: jnp.expm1(jnp.minimum(x, 87.5))

  # The loss when not in one of the special cases.
  # Clamp |alpha| to be >= machine epsilon so that it's safe to divide by.
  a = jnp.where(alpha >= 0, jnp.ones_like(alpha),
                -jnp.ones_like(alpha)) * jnp.maximum(eps, jnp.abs(alpha))
  # Clamp |2 - alpha| to be >= machine epsilon so that it's safe to divide by.
  b = jnp.maximum(eps, jnp.abs(alpha - 2))
  loss_ow = (b / a) * ((loss_two / (0.5 * b) + 1)**(0.5 * alpha) - 1)

  # Select which of the cases of the loss to return as a function of alpha.
  return jnp.where(
      alpha == -jnp.inf, -expm1_safe(-loss_two),
      jnp.where(
          alpha == 0, log1p_safe(loss_two),
          jnp.where(
              alpha == 2, loss_two,
              jnp.where(alpha == jnp.inf, expm1_safe(loss_two), loss_ow))))
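# A brief usage sketch (assumed, not part of the original file): check the
# alpha=2 and alpha=0 special cases of `lossfun` against their closed forms.
import jax.numpy as jnp

x = jnp.linspace(-5., 5., 11)
scale = 1.0
l2 = lossfun(x, 2.0, scale)      # should equal 0.5 * (x / scale)**2
cauchy = lossfun(x, 0.0, scale)  # should equal log1p(0.5 * (x / scale)**2)
assert jnp.allclose(l2, 0.5 * (x / scale) ** 2)
assert jnp.allclose(cauchy, jnp.log1p(0.5 * (x / scale) ** 2))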
def categorical_sample(key, probs):
  """Sample from a set of discrete probabilities."""
  probs = probs / probs.sum(axis=-1, keepdims=True)
  cpi = jnp.cumsum(probs, axis=-1)
  eps = jnp.finfo(probs.dtype).eps
  rnds = jax.random.uniform(
      key=key, shape=probs.shape[:-1] + (1,), dtype=probs.dtype, minval=eps)
  return jnp.argmin(jnp.logical_or(rnds > cpi, probs < eps), axis=-1)
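# A quick usage sketch (illustrative assumption): draw many samples and compare
# the empirical frequencies against the target probabilities.
import jax
import jax.numpy as jnp

probs = jnp.array([0.2, 0.5, 0.3])
keys = jax.random.split(jax.random.PRNGKey(0), 10000)
samples = jax.vmap(lambda k: categorical_sample(k, probs))(keys)
freqs = jnp.bincount(samples, length=3) / samples.shape[0]
print(freqs)  # expected to be close to [0.2, 0.5, 0.3]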
def visualize_depth(x, acc, lo=None, hi=None):
  """Visualizes depth maps."""
  depth_curve_fn = lambda x: -jnp.log(x + jnp.finfo(jnp.float32).eps)
  return visualize_cmap(x, acc, cm.get_cmap('turbo'),
                        curve_fn=depth_curve_fn, lo=lo, hi=hi)
def visualize_cmap(value,
                   weight,
                   colormap,
                   lo=None,
                   hi=None,
                   percentile=99.,
                   curve_fn=lambda x: x,
                   modulus=None,
                   matte_background=True):
  """Visualize a 1D image and a 1D weighting according to some colormap.

  Args:
    value: A 1D image.
    weight: A weight map, in [0, 1].
    colormap: A colormap function.
    lo: The lower bound to use when rendering, if None then use a percentile.
    hi: The upper bound to use when rendering, if None then use a percentile.
    percentile: What percentile of the value map to crop to when automatically
      generating `lo` and `hi`. Depends on `weight` as well as `value`.
    curve_fn: A curve function that gets applied to `value`, `lo`, and `hi`
      before the rest of visualization. Good choices: x, 1/(x+eps), log(x+eps).
    modulus: If not None, mod the normalized value by `modulus`. Use (0, 1]. If
      `modulus` is not None, `lo`, `hi` and `percentile` will have no effect.
    matte_background: If True, matte the image over a checkerboard.

  Returns:
    A colormap rendering.
  """
  # Identify the values that bound the middle of `value` according to `weight`.
  lo_auto, hi_auto = math.weighted_percentile(
      value, weight, [50 - percentile / 2, 50 + percentile / 2])

  # If `lo` or `hi` are None, use the automatically-computed bounds above.
  eps = jnp.finfo(jnp.float32).eps
  lo = lo or (lo_auto - eps)
  hi = hi or (hi_auto + eps)

  # Curve all values.
  value, lo, hi = [curve_fn(x) for x in [value, lo, hi]]

  # Wrap the values around if requested.
  if modulus:
    value = jnp.mod(value, modulus) / modulus
  else:
    # Otherwise, just scale to [0, 1].
    value = jnp.nan_to_num(
        jnp.clip((value - jnp.minimum(lo, hi)) / jnp.abs(hi - lo), 0, 1))

  if colormap:
    colorized = colormap(value)[:, :, :3]
  else:
    assert len(value.shape) == 3 and value.shape[-1] == 3
    colorized = value

  return matte(colorized, weight) if matte_background else colorized
def testRngUniform(self, dtype):
  key = random.PRNGKey(0)
  rand = lambda key: random.uniform(key, (10000,), dtype)
  crand = api.jit(rand)

  uncompiled_samples = rand(key)
  compiled_samples = crand(key)

  for samples in [uncompiled_samples, compiled_samples]:
    self._CheckCollisions(samples, np.finfo(dtype).nmant)
    self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.uniform().cdf)
def debiased_moments(self):
  """Returns debiased moments as in Adam."""
  tiny = jnp.finfo(self.decay_product).tiny
  debias = 1.0 / jnp.maximum(1 - self.decay_product, tiny)
  mean = jax.tree_map(lambda m1: m1 * debias, self.mu)
  # This computation of the variance may lose some numerical precision, if
  # the mean is not approximately zero.
  variance = jax.tree_map(
      lambda m2, m: jnp.maximum(0.0, m2 * debias - jnp.square(m)),
      self.nu, mean)
  return EmaMoments(mean=mean, variance=variance)
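# A minimal standalone sketch of the same Adam-style bias correction for scalar
# moments (function and variable names here are illustrative assumptions, not
# taken from the class above; `decay ** step` plays the role of decay_product).
import jax.numpy as jnp

def debias_moments(mu, nu, decay, step):
  tiny = jnp.finfo(jnp.result_type(mu)).tiny
  debias = 1.0 / jnp.maximum(1 - decay ** step, tiny)
  mean = mu * debias
  variance = jnp.maximum(0.0, nu * debias - jnp.square(mean))
  return mean, variance

mean, variance = debias_moments(jnp.array(0.05), jnp.array(0.01), decay=0.9, step=1)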
def test_bicgstab_on_random_system(self, shape, dtype, preconditioner):
  rng = jtu.rand_default(self.rng())
  A = rng(shape, dtype)
  solution = rng(shape[1:], dtype)
  M = self._fetch_preconditioner(preconditioner, A, rng=rng)
  b = matmul_high_precision(A, solution)
  tol = shape[0] * jnp.finfo(A.dtype).eps
  x, info = jax.scipy.sparse.linalg.bicgstab(A, b, tol=tol, atol=tol, M=M)
  using_x64 = solution.dtype in (np.float64, np.complex128)
  solution_tol = 1e-8 if using_x64 else 1e-4
  self.assertAllClose(x, solution, atol=solution_tol, rtol=solution_tol)
def inverse_softplus(y):
  """Inverse of jax.nn.softplus, adapted from TensorFlow Probability."""
  threshold = jnp.log(jnp.finfo(jnp.float32).eps) + 2.
  is_too_small = y < jnp.exp(threshold)
  is_too_large = y > -threshold
  too_small_value = jnp.log(y)
  too_large_value = y
  # The middle branch is evaluated on clamped inputs so that the masked-out
  # regions do not produce NaNs.
  y = jnp.where(is_too_small | is_too_large, 1., y)
  x = y + jnp.log(-jnp.expm1(-y))
  return jnp.where(is_too_small, too_small_value,
                   jnp.where(is_too_large, too_large_value, x))
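# A small usage sketch (assumed): round-trip through jax.nn.softplus and check
# the relative error, which should sit at float32 roundoff level.
import jax
import jax.numpy as jnp

y = jnp.logspace(-6, 2, 9)
x = inverse_softplus(y)
print(jnp.max(jnp.abs(jax.nn.softplus(x) - y) / y))  # expected ~1e-6 or smaller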
def sample(self, key, sample_shape=()):
  assert is_prng_key(key)
  dtype = jnp.result_type(float)
  finfo = jnp.finfo(dtype)
  minval = finfo.tiny
  u = random.uniform(key, shape=sample_shape + self.batch_shape, minval=minval)
  loc = self.base_dist.loc
  sign = jnp.where(loc >= self.low, 1.0, -1.0)
  return (1 - sign) * loc + sign * self.base_dist.icdf(
      (1 - u) * self._tail_prob_at_low + u * self._tail_prob_at_high)
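# A standalone sketch of the same inverse-CDF idea for a standard normal
# truncated to [low, high] (illustrative only; uses jax.scipy directly rather
# than the distribution classes above).
import jax
import jax.numpy as jnp
from jax.scipy.special import ndtri
from jax.scipy.stats import norm

def sample_truncated_normal(key, low, high, shape=()):
  cdf_low, cdf_high = norm.cdf(low), norm.cdf(high)
  tiny = jnp.finfo(jnp.result_type(float)).tiny
  u = jax.random.uniform(key, shape=shape, minval=tiny)
  return ndtri(cdf_low + u * (cdf_high - cdf_low))

samples = sample_truncated_normal(jax.random.PRNGKey(0), -1.0, 2.0, (5,))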
def _max_condition_number_to_be_non_singular(self):
  """Return the maximum condition number that we consider nonsingular."""
  with ops.name_scope("max_nonsingular_condition_number"):
    dtype_eps = np.finfo(self.dtype.as_numpy_dtype).eps
    eps = _ops.cast(
        math_ops.reduce_max([
            100.,
            _ops.cast(self.range_dimension_tensor(), self.dtype),
            _ops.cast(self.domain_dimension_tensor(), self.dtype)
        ]), self.dtype) * dtype_eps
    return 1. / eps
def _sample_n(self, key: PRNGKey, n: int) -> Array:
  """See `Distribution._sample_n`."""
  out_shape = (n,) + self.batch_shape
  dtype = jnp.result_type(self._loc, self._scale)
  uniform = jax.random.uniform(
      key, shape=out_shape, dtype=dtype,
      minval=jnp.finfo(dtype).tiny, maxval=1.)
  rnd = jnp.log(uniform) - jnp.log1p(-uniform)
  return self._scale * rnd + self._loc
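# A quick standalone check (illustrative assumption): the logit-of-uniform trick
# above matches the statistics of jax.random.logistic.
import jax
import jax.numpy as jnp

u = jax.random.uniform(jax.random.PRNGKey(0), (100000,),
                       minval=jnp.finfo(jnp.float32).tiny, maxval=1.)
x = jnp.log(u) - jnp.log1p(-u)
y = jax.random.logistic(jax.random.PRNGKey(1), (100000,))
print(x.mean(), y.mean())  # both ~0
print(x.std(), y.std())    # both ~pi / sqrt(3) ~ 1.81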
def _check_symmetry(x: jnp.ndarray) -> bool:
  """Check if the array is symmetric."""
  m, n = x.shape
  eps = jnp.finfo(x.dtype).eps
  tol = 50.0 * eps
  is_symmetric = False
  if m == n:
    if np.linalg.norm(x - x.T.conj()) / np.linalg.norm(x) < tol:
      is_symmetric = True
  return is_symmetric
def _svd_tall_and_square_input(
    a: Any, hermitian: bool, compute_uv: bool,
    max_iterations: int) -> Union[Any, Sequence[Any]]:
  """Singular value decomposition for m x n matrix and m >= n.

  Args:
    a: A matrix of shape `m x n` with `m >= n`.
    hermitian: True if `a` is Hermitian.
    compute_uv: Whether to compute also `u` and `v` in addition to `s`.
    max_iterations: The predefined maximum number of iterations of QDWH.

  Returns:
    A 3-tuple (`u`, `s`, `v`), where `u` is a unitary matrix of shape `m x n`,
    `s` is a vector of length `n` containing the singular values in descending
    order, `v` is a unitary matrix of shape `n x n`, and
    `a = (u * s) @ v.T.conj()`. For `compute_uv=False`, only `s` is returned.
  """
  u, h, _, _ = lax.linalg.qdwh(a, is_hermitian=hermitian,
                               max_iterations=max_iterations)

  # TODO: Use `eigvals_only=True` if `compute_uv=False`.
  v, s = lax.linalg.eigh(h)

  # Flips the singular values into descending order.
  s_out = jnp.flip(s)

  if not compute_uv:
    return s_out

  # Reorders the eigenvectors to match.
  v_out = jnp.fliplr(v)

  u_out = u @ v_out

  # Makes a correction if the `u` computed by qdwh is not unitary.
  # Section 5.5 of Nakatsukasa, Yuji, and Nicholas J. Higham. "Stable and
  # efficient spectral divide and conquer algorithms for the symmetric
  # eigenvalue decomposition and the SVD." SIAM Journal on Scientific Computing
  # 35, no. 3 (2013): A1325-A1349.
  def correct_rank_deficiency(u_out):
    u_out, r = lax.linalg.qr(u_out, full_matrices=False)
    u_out = u_out @ jnp.diag(lax.sign(jnp.diag(r)))
    return u_out

  eps = float(jnp.finfo(a.dtype).eps)
  u_out = lax.cond(s[0] < a.shape[1] * eps * s_out[0],
                   correct_rank_deficiency,
                   lambda u_out: u_out,
                   operand=u_out)

  return (u_out, s_out, v_out)
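# A hedged usage sketch (assumes a recent JAX with `lax.linalg.qdwh`/`eigh` and
# the function above in scope): compare against jnp.linalg.svd on a random tall
# matrix and check the reconstruction.
import jax
import jax.numpy as jnp

a = jax.random.normal(jax.random.PRNGKey(0), (6, 3))
u, s, v = _svd_tall_and_square_input(a, hermitian=False, compute_uv=True,
                                     max_iterations=10)
print(jnp.allclose(s, jnp.linalg.svd(a, compute_uv=False), atol=1e-4))  # expected True
print(jnp.allclose((u * s) @ v.T.conj(), a, atol=1e-4))                 # expected True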
def testDynamicIndexingWithIntegersGrads(self, shape, dtype, rng_factory, indexer):
  rng = rng_factory()
  tol = 1e-2 if jnp.finfo(dtype).bits == 32 else None
  unpacked_indexer, pack_indexer = self._ReplaceSlicesWithTuples(indexer)

  @api.jit
  def fun(unpacked_indexer, x):
    indexer = pack_indexer(unpacked_indexer)
    return x[indexer]

  arr = rng(shape, dtype)
  check_grads(partial(fun, unpacked_indexer), (arr,), 2, tol, tol, tol)
def _inverse(self, y):
  # inverse stick-breaking
  remainder = 1 - jnp.cumsum(jnp.abs(y[..., :-1]), axis=-1)
  pad_width = [(0, 0)] * (y.ndim - 1) + [(1, 0)]
  remainder = jnp.pad(remainder, pad_width, mode="constant", constant_values=1.0)
  finfo = jnp.finfo(y.dtype)
  remainder = jnp.clip(remainder, a_min=finfo.tiny)
  t = y / remainder
  # inverse of tanh
  t = jnp.clip(t, a_min=-1 + finfo.eps, a_max=1 - finfo.eps)
  return jnp.arctanh(t)
def __call__(self, shape: Sequence[int], dtype: Any) -> jnp.ndarray:
  real_dtype = jnp.finfo(dtype).dtype
  m = jax.lax.convert_element_type(self.mean, dtype)
  s = jax.lax.convert_element_type(self.stddev, real_dtype)
  is_complex = jnp.issubdtype(dtype, jnp.complexfloating)
  if is_complex:
    shape = [2, *shape]
  unscaled = jax.random.truncated_normal(hk.next_rng_key(), -2., 2., shape,
                                         real_dtype)
  if is_complex:
    unscaled = unscaled[0] + 1j * unscaled[1]
  return s * unscaled + m
def test_no_privacy(self):
  """l2_norm_clip=MAX_FLOAT32 and noise_multiplier=0 should recover SGD."""
  dp_agg = privacy.differentially_private_aggregate(
      l2_norm_clip=jnp.finfo(jnp.float32).max,
      noise_multiplier=0.,
      seed=0)
  state = dp_agg.init(self.params)
  update_fn = self.variant(dp_agg.update)
  mean_grads = jax.tree_map(lambda g: g.mean(0), self.per_eg_grads)

  for _ in range(3):
    updates, state = update_fn(self.per_eg_grads, state)
    chex.assert_tree_all_close(updates, mean_grads)
def testRngUniform(self, dtype):
  if jtu.device_under_test() == "tpu" and jnp.dtype(dtype).itemsize < 3:
    raise SkipTest("random.uniform() not supported on TPU for 16-bit types.")
  key = random.PRNGKey(0)
  rand = lambda key: random.uniform(key, (10000,), dtype)
  crand = api.jit(rand)

  uncompiled_samples = rand(key)
  compiled_samples = crand(key)

  for samples in [uncompiled_samples, compiled_samples]:
    self._CheckCollisions(samples, jnp.finfo(dtype).nmant)
    self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.uniform().cdf)
def makegauss2D(shape=(3, 3), sigma=0.5):
  """2D gaussian mask - should give the same result as MATLAB's
  fspecial('gaussian', [shape], [sigma])
  """
  m, n = [(ss - 1.0) / 2.0 for ss in shape]
  y, x = jnp.meshgrid(jnp.arange(-m, m + 1), jnp.arange(-n, n + 1))
  h = jnp.exp(-(x * x + y * y) / (2.0 * sigma * sigma))
  h = h.at[h < jnp.finfo(h.dtype).eps * h.max()].set(0)
  sumh = h.sum()
  if sumh != 0:
    h /= sumh
  return h
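# A small usage sketch (illustrative): a 5x5 kernel with sigma=1.0 sums to 1 and
# peaks at the center of the grid.
import jax.numpy as jnp

kernel = makegauss2D(shape=(5, 5), sigma=1.0)
print(kernel.shape, float(kernel.sum()))  # (5, 5) 1.0
print(int(jnp.argmax(kernel)))            # 12, the flat index of the center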
def __call__(self, name, fn, obs):
  assert obs is None, "TransformReparam does not support observe statements"
  shape = fn.shape()
  fn, expand_shape, event_dim = self._unwrap(fn)
  transform = uniform_reparam_transform(fn)
  tiny = jnp.finfo(jnp.result_type(float)).tiny

  x = numpyro.sample(
      "{}_base".format(name),
      dist.Uniform(tiny, 1).expand(shape).to_event(event_dim).mask(False),
  )
  # Simulate a numpyro.deterministic() site.
  return None, transform(x)