def body_fn(state):
    i, key, _, _ = state
    key, subkey = random.split(key)

    if radius is None or prototype_params is None:
        # Wrap model in a `substitute` handler to initialize from `init_loc_fn`.
        seeded_model = substitute(seed(model, subkey), substitute_fn=init_strategy)
        model_trace = trace(seeded_model).get_trace(*model_args, **model_kwargs)
        constrained_values, inv_transforms = {}, {}
        for k, v in model_trace.items():
            if v['type'] == 'sample' and not v['is_observed'] and not v['fn'].is_discrete:
                constrained_values[k] = v['value']
                inv_transforms[k] = biject_to(v['fn'].support)
        params = transform_fn(inv_transforms,
                              {k: v for k, v in constrained_values.items()},
                              invert=True)
    else:  # this branch doesn't require tracing the model
        params = {}
        for k, v in prototype_params.items():
            if k in init_values:
                params[k] = init_values[k]
            else:
                params[k] = random.uniform(subkey, jnp.shape(v), minval=-radius, maxval=radius)
                key, subkey = random.split(key)

    potential_fn = partial(potential_energy, model, model_args, model_kwargs, enum=enum)
    pe, z_grad = value_and_grad(potential_fn)(params)
    z_grad_flat = ravel_pytree(z_grad)[0]
    is_valid = jnp.isfinite(pe) & jnp.all(jnp.isfinite(z_grad_flat))
    return i + 1, key, (params, pe, z_grad), is_valid

def body_fn(state):
    i, key, _, _ = state
    key, subkey = random.split(key)

    # Wrap model in a `substitute` handler to initialize from `init_loc_fn`.
    # Use `block` to not record sample primitives in `init_loc_fn`.
    seeded_model = substitute(model, substitute_fn=block(seed(init_strategy, subkey)))
    model_trace = trace(seeded_model).get_trace(*model_args, **model_kwargs)
    constrained_values, inv_transforms = {}, {}
    for k, v in model_trace.items():
        if v['type'] == 'sample' and not v['is_observed']:
            if v['intermediates']:
                constrained_values[k] = v['intermediates'][0][0]
                inv_transforms[k] = biject_to(v['fn'].base_dist.support)
            else:
                constrained_values[k] = v['value']
                inv_transforms[k] = biject_to(v['fn'].support)
        elif v['type'] == 'param' and param_as_improper:
            constraint = v['kwargs'].pop('constraint', real)
            transform = biject_to(constraint)
            if isinstance(transform, ComposeTransform):
                base_transform = transform.parts[0]
                inv_transforms[k] = base_transform
                constrained_values[k] = base_transform(transform.inv(v['value']))
            else:
                inv_transforms[k] = transform
                constrained_values[k] = v['value']

    params = transform_fn(inv_transforms,
                          {k: v for k, v in constrained_values.items()},
                          invert=True)
    potential_fn = jax.partial(potential_energy, model, inv_transforms, model_args, model_kwargs)
    pe, param_grads = value_and_grad(potential_fn)(params)
    z_grad = ravel_pytree(param_grads)[0]
    is_valid = np.isfinite(pe) & np.all(np.isfinite(z_grad))
    return i + 1, key, params, is_valid

def testScaleAndTranslateGradFinite(self, antialias):
    image_shape = [1, 6, 7, 1]
    target_shape = [1, 3, 3, 1]

    data = [51, 38, 32, 89, 41, 21, 97, 51, 33, 87, 89, 34, 21, 97, 43, 25,
            25, 92, 41, 11, 84, 11, 55, 111, 23, 99, 50, 83, 13, 92, 52, 43,
            90, 43, 14, 89, 71, 32, 23, 23, 35, 93]
    x = jnp.array(data, dtype=jnp.float32).reshape(image_shape)

    scale_a = jnp.array([1.0, 0.35, 0.4, 1.0], dtype=jnp.float32)
    translation_a = jnp.array([0.0, 0.2, 0.1, 0.0], dtype=jnp.float32)

    def scale_fn(s):
        return jnp.sum(jax.image.scale_and_translate(
            x, target_shape, (0, 1, 2, 3), s, translation_a, "linear", antialias,
            precision=jax.lax.Precision.HIGHEST))

    scale_out = jax.grad(scale_fn)(scale_a)
    self.assertTrue(jnp.all(jnp.isfinite(scale_out)))

    def translate_fn(t):
        return jnp.sum(jax.image.scale_and_translate(
            x, target_shape, (0, 1, 2, 3), scale_a, t, "linear", antialias,
            precision=jax.lax.Precision.HIGHEST))

    translate_out = jax.grad(translate_fn)(translation_a)
    self.assertTrue(jnp.all(jnp.isfinite(translate_out)))

def _param_idx_to_str(idx: int) -> str:
    param = self.x[idx]

    if self.sig_x is None:
        sig = None
    else:
        sig = self.sig_x[idx]

    if self.bounds is None:
        low = None
        upp = None
    else:
        low = self.bounds[idx, 0]
        upp = self.bounds[idx, 1]

    params_str = f" {param}"
    if sig is not None:
        params_str += f" +/- {sig}"
    params_str += ","
    if low is not None:
        if jnp.isfinite(low):
            params_str += f"\t [Lower Bound = {low}]"
    if upp is not None:
        if jnp.isfinite(upp):
            params_str += f"\t [Upper Bound = {upp}]"

    return params_str

def body_fn(state):
    i, key, _, _ = state
    key, subkey = random.split(key)

    if radius is None or prototype_params is None:
        # XXX: we don't want to apply enum to draw latent samples
        model_ = model
        if enum:
            from numpyro.contrib.funsor import enum as enum_handler

            if isinstance(model, substitute) and isinstance(model.fn, enum_handler):
                model_ = substitute(model.fn.fn, data=model.data)
            elif isinstance(model, enum_handler):
                model_ = model.fn

        # Wrap model in a `substitute` handler to initialize from `init_loc_fn`.
        seeded_model = substitute(seed(model_, subkey), substitute_fn=init_strategy)
        model_trace = trace(seeded_model).get_trace(*model_args, **model_kwargs)
        constrained_values, inv_transforms = {}, {}
        for k, v in model_trace.items():
            if (
                v["type"] == "sample"
                and not v["is_observed"]
                and not v["fn"].is_discrete
            ):
                constrained_values[k] = v["value"]
                inv_transforms[k] = biject_to(v["fn"].support)
        params = transform_fn(
            inv_transforms,
            {k: v for k, v in constrained_values.items()},
            invert=True,
        )
    else:  # this branch doesn't require tracing the model
        params = {}
        for k, v in prototype_params.items():
            if k in init_values:
                params[k] = init_values[k]
            else:
                params[k] = random.uniform(
                    subkey, jnp.shape(v), minval=-radius, maxval=radius
                )
                key, subkey = random.split(key)

    potential_fn = partial(
        potential_energy, model, model_args, model_kwargs, enum=enum
    )
    if validate_grad:
        if forward_mode_differentiation:
            pe = potential_fn(params)
            z_grad = jacfwd(potential_fn)(params)
        else:
            pe, z_grad = value_and_grad(potential_fn)(params)
        z_grad_flat = ravel_pytree(z_grad)[0]
        is_valid = jnp.isfinite(pe) & jnp.all(jnp.isfinite(z_grad_flat))
    else:
        pe = potential_fn(params)
        is_valid = jnp.isfinite(pe)
        z_grad = None

    return i + 1, key, (params, pe, z_grad), is_valid

def testDerivativeIsMonotonicWrtX(self):
    # Check that the loss increases monotonically with |x|.
    _, _, x, alpha, _, d_x, _, _ = self._precompute_lossfun_inputs()
    # This is just to suppress a warning below.
    d_x = jnp.where(jnp.isfinite(d_x), d_x, jnp.zeros_like(d_x))
    mask = jnp.isfinite(alpha) & (jnp.abs(d_x) > (300. * jnp.finfo(jnp.float32).eps))
    chex.assert_tree_all_close(jnp.sign(d_x[mask]), jnp.sign(x[mask]))

def keep_step(grad_norm):
    keep_threshold = p.skip_step_gradient_norm_value
    if keep_threshold:
        return jnp.logical_and(
            jnp.all(jnp.isfinite(grad_norm)),
            jnp.all(jnp.less(grad_norm, keep_threshold)))
    else:
        return jnp.all(jnp.isfinite(grad_norm))

def test_mean_var(jax_dist, sp_dist, params):
    n = 20000 if jax_dist in [dist.LKJ, dist.LKJCholesky] else 200000
    d_jax = jax_dist(*params)
    k = random.PRNGKey(0)
    samples = d_jax.sample(k, sample_shape=(n,))
    # check with suitable scipy implementation if available
    if sp_dist and not _is_batched_multivariate(d_jax):
        d_sp = sp_dist(*params)
        try:
            sp_mean = d_sp.mean()
        except TypeError:  # mvn does not have .mean() method
            sp_mean = d_sp.mean
        # for multivariate distns try .cov first
        if d_jax.event_shape:
            try:
                sp_var = np.diag(d_sp.cov())
            except TypeError:  # mvn does not have .cov() method
                sp_var = np.diag(d_sp.cov)
            except AttributeError:
                sp_var = d_sp.var()
        else:
            sp_var = d_sp.var()
        assert_allclose(d_jax.mean, sp_mean, rtol=0.01, atol=1e-7)
        assert_allclose(d_jax.variance, sp_var, rtol=0.01, atol=1e-7)
        if np.all(np.isfinite(sp_mean)):
            assert_allclose(np.mean(samples, 0), d_jax.mean, rtol=0.05, atol=1e-2)
        if np.all(np.isfinite(sp_var)):
            assert_allclose(np.std(samples, 0), np.sqrt(d_jax.variance), rtol=0.05, atol=1e-2)
    elif jax_dist in [dist.LKJ, dist.LKJCholesky]:
        if jax_dist is dist.LKJCholesky:
            corr_samples = np.matmul(samples, np.swapaxes(samples, -2, -1))
        else:
            corr_samples = samples
        dimension, concentration, _ = params
        # marginal of off-diagonal entries
        marginal = dist.Beta(concentration + 0.5 * (dimension - 2),
                             concentration + 0.5 * (dimension - 2))
        # scale statistics due to linear mapping
        marginal_mean = 2 * marginal.mean - 1
        marginal_std = 2 * np.sqrt(marginal.variance)
        expected_mean = np.broadcast_to(np.reshape(marginal_mean, np.shape(marginal_mean) + (1, 1)),
                                        np.shape(marginal_mean) + d_jax.event_shape)
        expected_std = np.broadcast_to(np.reshape(marginal_std, np.shape(marginal_std) + (1, 1)),
                                       np.shape(marginal_std) + d_jax.event_shape)
        # diagonal elements of correlation matrices are 1
        expected_mean = expected_mean * (1 - np.identity(dimension)) + np.identity(dimension)
        expected_std = expected_std * (1 - np.identity(dimension))

        assert_allclose(np.mean(corr_samples, axis=0), expected_mean, atol=0.01)
        assert_allclose(np.std(corr_samples, axis=0), expected_std, atol=0.01)
    else:
        if np.all(np.isfinite(d_jax.mean)):
            assert_allclose(np.mean(samples, 0), d_jax.mean, rtol=0.05, atol=1e-2)
        if np.all(np.isfinite(d_jax.variance)):
            assert_allclose(np.std(samples, 0), np.sqrt(d_jax.variance), rtol=0.05, atol=1e-2)

def body_fn(iteration, const, state, compute_error):
    """Carries out sinkhorn iteration.

    Depending on lse_mode, these iterations can be either in:
      - log-space for numerical stability.
      - scaling space, using standard kernel-vector multiply operations.

    Args:
      iteration: iteration number
      const: tuple of constant parameters that do not change throughout the
        loop, here the geometry and the marginals a, b.
      state: potential/scaling variables updated in the loop & error log.
      compute_error: flag to indicate this iteration computes/stores an error

    Returns:
      state variables, i.e. errors and updated f_u, g_v potentials.
    """
    geom, a, b, _ = const
    errors, f_u, g_v = state

    # compute momentum term if needed, using previously seen errors.
    w = jax.lax.stop_gradient(jnp.where(
        iteration >= (inner_iterations * chg_momentum_from + min_iterations),
        get_momentum(errors, chg_momentum_from),
        momentum_default))

    # Sinkhorn updates using momentum, in either scaling or potential form.
    if parallel_dual_updates:
        old_g_v = g_v
    if lse_mode:
        new_g_v = tau_b * geom.update_potential(f_u, g_v, jnp.log(b),
                                                iteration, axis=0)
        g_v = (1.0 - w) * jnp.where(jnp.isfinite(g_v), g_v, 0.0) + w * new_g_v
        new_f_u = tau_a * geom.update_potential(
            f_u, old_g_v if parallel_dual_updates else g_v,
            jnp.log(a), iteration, axis=1)
        f_u = (1.0 - w) * jnp.where(jnp.isfinite(f_u), f_u, 0.0) + w * new_f_u
    else:
        new_g_v = geom.update_scaling(f_u, b, iteration, axis=0) ** tau_b
        g_v = jnp.where(g_v > 0, g_v, 1) ** (1.0 - w) * new_g_v ** w
        new_f_u = geom.update_scaling(
            old_g_v if parallel_dual_updates else g_v,
            a, iteration, axis=1) ** tau_a
        f_u = jnp.where(f_u > 0, f_u, 1) ** (1.0 - w) * new_f_u ** w

    # re-computes error if compute_error is True, else set it to inf.
    err = jnp.where(
        jnp.logical_and(compute_error, iteration >= min_iterations),
        marginal_error(geom, a, b, tau_a, tau_b, f_u, g_v, norm_error, lse_mode),
        jnp.inf)

    errors = jax.ops.index_update(
        errors, jax.ops.index[iteration // inner_iterations, :], err)
    return errors, f_u, g_v

def logmarglike_lineargaussianmodel_onetransfer(M_T, y, yinvvar, logyinvvar=None):
    """Fit linear model to one Gaussian data set, with no (=uniform) prior
    on the linear components.

    Parameters
    ----------
    y, yinvvar, logyinvvar : ndarray (n_pix_y)
        data and data inverse variances. Zeros will be ignored.
    M_T : ndarray (n_components, n_pix_y)
        design matrix of linear model

    Returns
    -------
    logfml : ndarray scalar
        log likelihood values with parameters marginalised and at best fit
    theta_map : ndarray (n_components)
        Best fit MAP parameters
    theta_cov : ndarray (n_components, n_components)
        Parameter covariance
    """
    # assert y.shape[-2] == yinvvar.shape[-2]
    assert y.shape[-1] == yinvvar.shape[-1]
    # assert y.shape[-1] == 1
    assert M_T.shape[-1] == yinvvar.shape[-1]
    assert np.all(np.isfinite(yinvvar))  # no negative elements
    assert np.all(np.isfinite(y))  # all finite
    assert np.all(np.isfinite(M_T))  # all finite
    assert np.count_nonzero(yinvvar) > 2  # at least two valid (non zero) pixels

    log2pi = np.log(2.0 * np.pi)
    nt = np.shape(M_T)[-2]
    ny = np.count_nonzero(yinvvar)
    M = np.transpose(M_T)  # (n_pix_y, n_components)
    Myinv = M * yinvvar[:, None]  # (n_pix_y, n_components)
    Hbar = np.matmul(M_T, Myinv)  # (n_components, n_components)
    etabar = np.sum(Myinv * y[:, None], axis=0)  # (n_components)
    theta_map = np.linalg.solve(Hbar, etabar)  # (n_components)
    theta_cov = np.linalg.inv(Hbar)  # (n_components, n_components)
    if logyinvvar is None:
        logyinvvar = np.where(yinvvar == 0, 0, np.log(yinvvar))
    logdetH = np.sum(logyinvvar)  # scalar
    xi1 = -0.5 * (ny * log2pi - logdetH + np.sum(y * y * yinvvar))  # scalar
    sign, logdetHbar = np.linalg.slogdet(Hbar)
    xi2 = -0.5 * (nt * log2pi - logdetHbar + np.sum(etabar * theta_map))
    logfml = xi1 - xi2
    return logfml, theta_map, theta_cov

def assert_tree_all_finite(tree_like: ArrayTree):
    """Assert all tensor leaves in a tree are finite.

    Args:
      tree_like: pytree with array leaves

    Raises:
      AssertionError: if any leaf in the tree is non-finite.
    """
    all_finite = jax.tree_util.tree_all(
        jax.tree_map(lambda x: jnp.all(jnp.isfinite(x)), tree_like))
    if not all_finite:
        is_finite = lambda x: "Finite" if jnp.all(jnp.isfinite(x)) else "Nonfinite"
        error_msg = jax.tree_map(is_finite, tree_like)
        raise AssertionError(f"Tree contains non-finite value: {error_msg}.")

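# Usage sketch for the tree-finiteness assertion above. Assumes `jax` and
# `jax.numpy` are importable and that `assert_tree_all_finite` is in scope;
# the example pytree and its values are illustrative, not from the original
# source.
import jax
import jax.numpy as jnp

params = {"w": jnp.ones((2, 2)), "b": jnp.array([0.0, jnp.nan])}
try:
    assert_tree_all_finite(params)
except AssertionError as e:
    # The error message maps each leaf to "Finite"/"Nonfinite"; the "b" leaf
    # above is reported as "Nonfinite" because it contains a NaN.
    print(e)
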
def testDerivativeIsBoundedWhenAlphaIsBelow1(self):
    # Assert that |d_x| < 1/scale when alpha <= 1.
    _, _, _, alpha, scale, d_x, _, _ = self._precompute_lossfun_inputs()
    mask = jnp.isfinite(alpha) & (alpha <= 1)
    grad = jnp.abs(d_x[mask])
    bound = ((1. + (300. * jnp.finfo(jnp.float32).eps)) / scale[mask])
    self.assertTrue(jnp.all(grad <= bound))

def sample_kernel(sa_state, model_args=(), model_kwargs=None):
    pe_fn = potential_fn
    if potential_fn_gen:
        pe_fn = potential_fn_gen(*model_args, **model_kwargs)
    zs, pes, loc, scale = sa_state.adapt_state
    # we recompute loc/scale after each iteration to avoid precision loss
    # XXX: consider to expose a setting to do this job periodically
    # to save some computations
    loc = jnp.mean(zs, 0)
    if scale.ndim == 2:
        cov = jnp.cov(zs, rowvar=False, bias=True)
        if cov.shape == ():  # JAX returns scalar for 1D input
            cov = cov.reshape((1, 1))
        cholesky = jnp.linalg.cholesky(cov)
        scale = jnp.where(jnp.any(jnp.isnan(cholesky)), scale, cholesky)
    else:
        scale = jnp.std(zs, 0)

    rng_key, rng_key_z, rng_key_reject, rng_key_accept = random.split(sa_state.rng_key, 4)
    _, unravel_fn = ravel_pytree(sa_state.z)

    z = loc + _sample_proposal(scale, rng_key_z)
    pe = pe_fn(unravel_fn(z))
    pe = jnp.where(jnp.isnan(pe), jnp.inf, pe)
    diverging = (pe - sa_state.potential_energy) > max_delta_energy

    # NB: all terms having the pattern *s will have shape N x ...
    # and all terms having the pattern *s_ will have shape (N + 1) x ...
    locs, scales = _get_proposal_loc_and_scale(zs, loc, scale, z)
    zs_ = jnp.concatenate([zs, z[None, :]])
    pes_ = jnp.concatenate([pes, pe[None]])
    locs_ = jnp.concatenate([locs, loc[None, :]])
    scales_ = jnp.concatenate([scales, scale[None, ...]])
    if scale.ndim == 2:  # dense_mass
        log_weights_ = dist.MultivariateNormal(locs_, scale_tril=scales_).log_prob(zs_) + pes_
    else:
        log_weights_ = dist.Normal(locs_, scales_).log_prob(zs_).sum(-1) + pes_
    # mask invalid values (nan, +inf) by -inf
    log_weights_ = jnp.where(jnp.isfinite(log_weights_), log_weights_, -jnp.inf)
    # get rejecting index
    j = random.categorical(rng_key_reject, log_weights_)
    zs = _numpy_delete(zs_, j)
    pes = _numpy_delete(pes_, j)
    loc = locs_[j]
    scale = scales_[j]
    adapt_state = SAAdaptState(zs, pes, loc, scale)

    # NB: weights[-1] / sum(weights) is the probability of rejecting the new sample `z`.
    accept_prob = 1 - jnp.exp(log_weights_[-1] - logsumexp(log_weights_))
    itr = sa_state.i + 1
    n = jnp.where(sa_state.i < wa_steps, itr, itr - wa_steps)
    mean_accept_prob = sa_state.mean_accept_prob + (accept_prob - sa_state.mean_accept_prob) / n

    # XXX: we make a modification of SA sampler in [1]
    # in [1], each MCMC state contains N points `zs`
    # here we do resampling to pick randomly a point from those N points
    k = random.categorical(rng_key_accept, jnp.zeros(zs.shape[0]))
    z = unravel_fn(zs[k])
    pe = pes[k]
    return SAState(itr, z, pe, accept_prob, mean_accept_prob, diverging, adapt_state, rng_key)

def sample(self, data=None, name=None, shape=None, obs=None):
    '''Sample responses'''
    name = name or self.name

    if data is None:
        X = self.X  # use same data used to create model
    else:
        info = self.X.design_info      # information from original data
        X = patsy.dmatrix(info, data)  # design matrix for new data

    linpred = np.array(X) @ self.theta
    if shape is not None:
        linpred = linpred.reshape(shape)  # reshape to tensor if requested

    fwd, inv = self.link()
    if self.guess is None:
        mu = inv(linpred)
    else:
        fwd_guess = fwd(self.guess)
        if not np.isfinite(fwd_guess):
            raise ValueError("Bad Guess")
        mu = inv(fwd_guess + linpred)

    y = numpyro.sample(name, self.family(mu), obs=obs)
    return y, mu, linpred

def test_elbo_dynamic_support():
    x_prior = dist.TransformedDistribution(
        dist.Normal(), [AffineTransform(0, 2), SigmoidTransform(), AffineTransform(0, 3)])
    x_guide = dist.Uniform(0, 3)

    def model():
        numpyro.sample('x', x_prior)

    def guide():
        numpyro.sample('x', x_guide)

    adam = optim.Adam(0.01)
    # set base value of x_guide to 0.9
    x_base = 0.9
    guide = substitute(guide, base_param_map={'x': x_base})
    svi = SVI(model, guide, elbo, adam)
    svi_state = svi.init(random.PRNGKey(0), (), ())
    actual_loss = svi.evaluate(svi_state)
    assert np.isfinite(actual_loss)
    x, _ = x_guide.transform_with_intermediates(x_base)
    expected_loss = x_guide.log_prob(x) - x_prior.log_prob(x)
    assert_allclose(actual_loss, expected_loss)

def update(updates, state, params=None):
    inner_state = state.inner_state
    flat_updates = tree_flatten(updates)[0]
    isfinite = jnp.all(
        jnp.array([jnp.all(jnp.isfinite(p)) for p in flat_updates]))
    notfinite_count = jnp.where(isfinite, jnp.zeros([], jnp.int64),
                                1 + state.notfinite_count)

    def do_update(_):
        return inner.update(updates, inner_state, params)

    def reject_update(_):
        return (tree_map(jnp.zeros_like, updates), inner_state)

    updates, new_inner_state = lax.cond(
        jnp.logical_or(isfinite, notfinite_count > max_consecutive_errors),
        do_update, reject_update, operand=None)

    return updates, ApplyIfFiniteState(
        notfinite_count=notfinite_count,
        last_finite=isfinite,
        total_notfinite=jnp.logical_not(isfinite) + state.total_notfinite,
        inner_state=new_inner_state)

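# The closure above resembles optax's `apply_if_finite` wrapper, which rejects
# updates containing NaN/inf until an error budget is exhausted. A minimal
# usage sketch, assuming optax is installed; the loss, parameters, and
# hyperparameters here are illustrative only.
import jax
import jax.numpy as jnp
import optax

params = {"w": jnp.ones((3,))}
opt = optax.apply_if_finite(optax.adam(1e-3), max_consecutive_errors=5)
opt_state = opt.init(params)

def loss_fn(p):
    return jnp.sum(p["w"] ** 2)

grads = jax.grad(loss_fn)(params)
# Non-finite gradients are replaced by zero updates until the budget runs out;
# finite gradients are passed through to the inner optimizer.
updates, opt_state = opt.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)
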
def _observe_binom_approx(name, latent, det_rate, det_conc, obs=None):
    '''Make observations of a latent variable using BinomialApprox.'''
    mask = True

    # Regularization: add reg to observed, and (reg/det_rate) to latent
    # The primary purpose is to avoid zeros, which are invalid values for
    # the Beta observation model.
    reg = 0.5
    latent = latent + (reg / det_rate)

    if obs is not None:
        '''
        Workaround for a jax issue: substitute default values AND mask out bad observations.
        See https://forum.pyro.ai/t/behavior-of-mask-handler-with-invalid-observation-possible-bug/1719/5
        '''
        mask = np.isfinite(obs)
        obs = np.where(mask, obs, 0.5 * latent)
        obs = obs + reg

    det_rate = np.broadcast_to(det_rate, latent.shape)
    det_conc = np.minimum(det_conc, latent)  # don't allow it to be *more* concentrated than Binomial

    d = BinomialApprox(latent + (reg / det_rate), det_rate, det_conc)

    with numpyro.handlers.mask(mask_array=mask):
        y = numpyro.sample(name, d, obs=obs)

    return y

def tmrca_sf(t: np.ndarray, y: np.ndarray, n: int) -> np.ndarray:
    """The survival function of the TMRCA at each time point

    Args:
        t: time grid (including zero and infinity)
        y: effective population size in each epoch
        n: number of sampled haplotypes
    """
    # epoch durations
    s = np.diff(t)
    logu = -s / y
    logu = np.concatenate((np.array([0]), logu))
    # the A_2j are the product of this matrix
    # NOTE: using letter "l" as a variable name to match text
    l = onp.arange(2, n + 1)[:, onp.newaxis]  # noqa: E741
    with onp.errstate(divide='ignore'):
        A2_terms = l * (l - 1) / (l * (l - 1) - l.T * (l.T - 1))
    onp.fill_diagonal(A2_terms, 1)
    A2 = np.prod(A2_terms, axis=0)
    binom_vec = l * (l - 1) / 2
    result = np.zeros(len(t))
    result = index_update(result, index[:-1],
                          np.squeeze(A2[np.newaxis, :]
                                     @ np.exp(np.cumsum(logu[np.newaxis, :-1], axis=1))
                                     ** binom_vec))
    assert np.all(np.isfinite(result))
    return result

def get_initial_state(system, rng, generate_x_obs_seq_init, dim_q, tol,
                      adam_step_size=2e-1, reg_coeff=5e-2, coarse_tol=1e-1,
                      max_iters=1000, max_num_tries=10):
    """Find an initial constraint satisfying state.

    Uses a heuristic combination of gradient-based minimisation of the norm of
    a modified constraint function plus a subsequent projection step using a
    quasi-Newton method, to try to find an initial point `q` such that
    `max(abs(constr(q)) < tol`.
    """
    # Use optimizers to set optimizer initialization and update functions
    opt_init, opt_update, get_params = opt.adam(adam_step_size)

    # Define a compiled update step
    @api.jit
    def step(i, opt_state, x_obs_seq_init):
        q, = get_params(opt_state)
        (obj, constr), grad = system.value_and_grad_init_objective(
            q, x_obs_seq_init, reg_coeff)
        opt_state = opt_update(i, grad, opt_state)
        return opt_state, obj, constr

    for t in range(max_num_tries):
        logging.info(f'Starting try {t+1}')
        q_init = rng.standard_normal(dim_q)
        x_obs_seq_init = generate_x_obs_seq_init(rng)
        opt_state = opt_init((q_init,))
        for i in range(max_iters):
            opt_state_next, norm, constr = step(i, opt_state, x_obs_seq_init)
            if not np.isfinite(norm):
                logger.info('Adam iteration diverged')
                break
            max_abs_constr = maximum_norm(constr)
            if max_abs_constr < coarse_tol:
                logging.info('Within coarse_tol attempting projection.')
                q_init, = get_params(opt_state)
                state = ConditionedDiffusionHamiltonianState(
                    q_init, x_obs_seq=x_obs_seq_init)
                try:
                    state = jitted_solve_projection_onto_manifold_quasi_newton(
                        state, state, 1., system, tol)
                except ConvergenceError:
                    logger.info('Quasi-Newton iteration diverged.')
                if np.max(np.abs(system.constr(state))) < tol:
                    logging.info('Found constraint satisfying state.')
                    state.mom = system.sample_momentum(state, rng)
                    return state
            if i % 100 == 0:
                logging.info(f'Iteration {i: >6}: mean|constr|^2 = {norm:.3e} '
                             f'max|constr| = {max_abs_constr:.3e}')
            opt_state = opt_state_next
    raise RuntimeError(f'Did not find valid state in {max_num_tries} tries.')

def body(state):
    p_k = -(state.H_k @ state.g_k)
    line_search_results = line_search(value_and_grad, state.x_k, p_k,
                                      old_fval=state.f_k, gfk=state.g_k,
                                      maxiter=ls_maxiter)
    state = state._replace(nfev=state.nfev + line_search_results.nfev,
                           ngev=state.ngev + line_search_results.ngev,
                           failed=line_search_results.failed,
                           ls_status=line_search_results.status)
    s_k = line_search_results.a_k * p_k
    x_kp1 = state.x_k + s_k
    f_kp1 = line_search_results.f_k
    g_kp1 = line_search_results.g_k
    # print(g_kp1)
    y_k = g_kp1 - state.g_k
    rho_k = jnp.reciprocal(y_k @ s_k)

    sy_k = s_k[:, None] * y_k[None, :]
    w = jnp.eye(d) - rho_k * sy_k
    H_kp1 = jnp.where(jnp.isfinite(rho_k),
                      jnp.linalg.multi_dot([w, state.H_k, w.T])
                      + rho_k * s_k[:, None] * s_k[None, :],
                      state.H_k)

    converged = jnp.linalg.norm(g_kp1, ord=norm) < gtol

    state = state._replace(converged=converged,
                           k=state.k + 1,
                           x_k=x_kp1,
                           f_k=f_kp1,
                           g_k=g_kp1,
                           H_k=H_kp1)
    return state

def cond_fn(iteration, const, state):
    threshold = const[-1]
    errors = state[0]
    err = errors[iteration // inner_iterations - 1, 0]
    return jnp.logical_or(iteration == 0,
                          jnp.logical_and(jnp.isfinite(err), err > threshold))

def test_discrete_barycenter_grid(self, lse_mode, debiased, epsilon):
    """Tests the discrete barycenters on a 5x5x5 grid.

    Puts two masses on opposing ends of the hypercube with small noise in
    between. Check that their W barycenter sits (mostly) at the middle of the
    hypercube (e.g. index (5x5x5-1)/2)

    Args:
      lse_mode: bool, lse or scaling computations.
      debiased: bool, use (or not) debiasing as proposed in
        https://arxiv.org/abs/2006.02575
      epsilon: float, regularization parameter
    """
    size = jnp.array([5, 5, 5])
    grid_3d = grid.Grid(grid_size=size, epsilon=epsilon)
    a = jnp.ones(size)
    b = jnp.ones(size)
    a = a.ravel()
    b = b.ravel()
    a = jax.ops.index_update(a, 0, 10000)
    b = jax.ops.index_update(b, -1, 10000)
    a = a / jnp.sum(a)
    b = b / jnp.sum(b)
    threshold = 1e-2
    _, _, bar, errors = db.discrete_barycenter(grid_3d,
                                               a=jnp.stack((a, b)),
                                               threshold=threshold,
                                               lse_mode=lse_mode,
                                               debiased=debiased)
    self.assertGreater(bar[(jnp.prod(size) - 1) // 2], 0.7)
    self.assertGreater(1, bar[(jnp.prod(size) - 1) // 2])
    err = errors[jnp.isfinite(errors)][-1]
    self.assertGreater(threshold, err)

def _transform(x: DeviceArray, bounds: Optional[DeviceArray]) -> DeviceArray:
    if bounds is None:
        return x
    low = bounds[:, 0]
    upp = bounds[:, 1]
    return jnp.where(
        jnp.isfinite(low) & jnp.isfinite(upp),
        _between(x, low, upp),
        jnp.where(
            jnp.isfinite(low),
            _greater_than(x, low),
            jnp.where(jnp.isfinite(upp), _less_than(x, upp), x),
        ),
    )

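# The helpers `_between`, `_greater_than`, and `_less_than` used above are not
# shown in this snippet. A minimal sketch of what such bound-respecting
# transforms might look like, using a sigmoid for two-sided bounds and softplus
# shifts for one-sided bounds; these definitions are illustrative assumptions,
# not the original implementations.
import jax
import jax.numpy as jnp


def _between(x, low, upp):
    # Map the real line into (low, upp) via a sigmoid.
    return low + (upp - low) * jax.nn.sigmoid(x)


def _greater_than(x, low):
    # Map the real line into (low, +inf) via softplus.
    return low + jax.nn.softplus(x)


def _less_than(x, upp):
    # Map the real line into (-inf, upp) via softplus.
    return upp - jax.nn.softplus(x)
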
def test_elbo_dynamic_support():
    x_prior = dist.Uniform(0, 5)
    x_unconstrained = 2.

    def model():
        numpyro.sample('x', x_prior)

    class _AutoGuide(AutoDiagonalNormal):
        def __call__(self, *args, **kwargs):
            return substitute(super(_AutoGuide, self).__call__,
                              {'_auto_latent': x_unconstrained})(*args, **kwargs)

    adam = optim.Adam(0.01)
    guide = _AutoGuide(model)
    svi = SVI(model, guide, adam, AutoContinuousELBO())
    svi_state = svi.init(random.PRNGKey(0))
    actual_loss = svi.evaluate(svi_state)
    assert np.isfinite(actual_loss)

    guide_log_prob = dist.Normal(
        guide._init_latent, guide._init_scale).log_prob(x_unconstrained).sum()
    transform = transforms.biject_to(constraints.interval(0, 5))
    x = transform(x_unconstrained)
    logdet = transform.log_abs_det_jacobian(x_unconstrained, x)
    model_log_prob = x_prior.log_prob(x) + logdet
    expected_loss = guide_log_prob - model_log_prob
    assert_allclose(actual_loss, expected_loss, rtol=1e-6)

def test_near_singular_inverse(self, jit):
    rng = jtu.rand_default(self.rng())

    @partial(_maybe_jit, jit, static_argnums=1)
    def near_singular_inverse(N=5, eps=1E-40):
        X = rng((N, N), dtype='float64')
        X = jnp.asarray(X)
        X = X.at[-1].mul(eps)
        return jnp.linalg.inv(X)

    with enable_x64():
        result_64 = near_singular_inverse()
        self.assertTrue(jnp.all(jnp.isfinite(result_64)))

    with disable_x64():
        result_32 = near_singular_inverse()
        self.assertTrue(jnp.all(~jnp.isfinite(result_32)))

def _max_mask_non_finite(x, axis=-1, keepdims=False, mask=0):
    """Returns `max` or `mask` if `max` is not finite."""
    m = np.max(x, axis=_astuple(axis), keepdims=keepdims)
    needs_masking = ~np.isfinite(m)
    if needs_masking.ndim > 0:
        m = np.where(needs_masking, mask, m)
    elif needs_masking:
        m = mask
    return m

def eval_and_stable_update(self, fn: Callable, state: _IterOptState) -> _IterOptState:
    """
    Like :meth:`eval_and_update` but when the value of the objective function
    or the gradients are not finite, we will not update the input `state`
    and will set the objective output to `nan`.

    :param fn: objective function.
    :param state: current optimizer state.
    :return: a pair of the output of objective function and the new optimizer state.
    """
    params = self.get_params(state)
    out, grads = value_and_grad(fn)(params)
    out, state = lax.cond(
        jnp.isfinite(out) & jnp.isfinite(ravel_pytree(grads)[0]).all(),
        lambda _: (out, self.update(grads, state)),
        lambda _: (jnp.nan, state),
        None)
    return out, state

def categorical_sample(key, probs):
    """Sample from a set of discrete probabilities."""
    probs = probs / probs.sum(axis=-1, keepdims=True)
    is_valid = jnp.logical_and(jnp.all(jnp.isfinite(probs)), jnp.all(probs >= 0))
    cpi = jnp.cumsum(probs, axis=-1)
    eps = jnp.finfo(probs.dtype).eps
    rnds = jax.random.uniform(
        key=key, shape=probs.shape[:-1] + (1,), dtype=probs.dtype, minval=eps)
    argmin = jnp.argmin(jnp.logical_or(rnds > cpi, probs < eps), axis=-1)
    return jnp.where(is_valid, argmin, -1)

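# Usage sketch for `categorical_sample` above. Assumes the same `jax`/`jnp`
# imports as the function itself; the probability vectors are illustrative.
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
probs = jnp.array([0.1, 0.2, 0.7])
idx = categorical_sample(key, probs)          # a sampled index in {0, 1, 2}

bad_probs = jnp.array([0.5, jnp.nan, 0.5])
bad_idx = categorical_sample(key, bad_probs)  # returns -1 for invalid inputs
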
def log_posterior(theta):
    log_prior_val = uniform_log_prior(theta)
    if np.isfinite(log_prior_val):
        preds = get_predictions(theta)
        log_lik_val = np.sum(
            log_likelihood_as_fxn_of_prediction(preds, expt_means, expt_uncertainties))
        return log_prior_val + log_lik_val
    else:
        return -np.inf

def test_near_singular_inverse(self, jit):
    if jtu.device_under_test() == "tpu":
        self.skipTest("64-bit inverse not available on TPU")

    @partial(_maybe_jit, jit, static_argnums=1)
    def near_singular_inverse(key, N, eps):
        X = random.uniform(key, (N, N))
        X = X.at[-1].mul(eps)
        return jnp.linalg.inv(X)

    key = random.PRNGKey(1701)
    eps = 1E-40
    N = 5

    with enable_x64():
        result_64 = near_singular_inverse(key, N, eps)
        self.assertTrue(jnp.all(jnp.isfinite(result_64)))

    with disable_x64():
        result_32 = near_singular_inverse(key, N, eps)
        self.assertTrue(jnp.all(~jnp.isfinite(result_32)))