def bc_flat(aug_t0, aug_t1):
    """Evaluate ``bc`` on two flat augmented states and return a flat residual.

    Both inputs are unraveled with ``unravel`` (from the enclosing scope), the
    boundary-condition function ``bc`` is applied, and its (pytree) output is
    flattened back into a single vector.
    """
    residual_tree = bc(unravel(aug_t0), unravel(aug_t1))
    return ravel_pytree(residual_tree)[0]
def body_fn(state):
    """One attempt at drawing initial parameters.

    ``state`` is ``(i, key, _, _)``; the last two entries are ignored here.
    Returns ``(i + 1, key, (params, pe, z_grad), is_valid)`` where ``pe`` is
    the potential energy at ``params`` and ``is_valid`` says whether ``pe``
    (and, if ``validate_grad``, every gradient entry) is finite.
    ``z_grad`` is ``None`` when ``validate_grad`` is false.
    """
    i, key, _, _ = state
    key, subkey = random.split(key)

    if radius is None or prototype_params is None:
        # XXX: we don't want to apply enum to draw latent samples
        model_ = model
        if enum:
            from numpyro.contrib.funsor import enum as enum_handler

            # Strip the enum handler (keeping any outer `substitute` data)
            # so latent draws are not enumerated.
            if isinstance(model, substitute) and isinstance(
                    model.fn, enum_handler):
                model_ = substitute(model.fn.fn, data=model.data)
            elif isinstance(model, enum_handler):
                model_ = model.fn

        # Wrap model in a `substitute` handler to initialize from `init_loc_fn`.
        seeded_model = substitute(seed(model_, subkey),
                                  substitute_fn=init_strategy)
        model_trace = trace(seeded_model).get_trace(
            *model_args, **model_kwargs)
        constrained_values, inv_transforms = {}, {}
        # Only latent (unobserved), continuous sample sites are initialized.
        for k, v in model_trace.items():
            if (v["type"] == "sample" and not v["is_observed"]
                    and not v["fn"].support.is_discrete):
                constrained_values[k] = v["value"]
                with helpful_support_errors(v):
                    inv_transforms[k] = biject_to(v["fn"].support)
        # invert=True: presumably maps the constrained values back to
        # unconstrained space via the biject_to transforms (confirm in
        # transform_fn).
        params = transform_fn(
            inv_transforms,
            {k: v for k, v in constrained_values.items()},
            invert=True,
        )
    else:  # this branch doesn't require tracing the model
        params = {}
        for k, v in prototype_params.items():
            if k in init_values:
                params[k] = init_values[k]
            else:
                # Uniform draw in [-radius, radius] with the prototype's shape;
                # a fresh subkey is split off after each draw.
                params[k] = random.uniform(subkey,
                                           jnp.shape(v),
                                           minval=-radius,
                                           maxval=radius)
                key, subkey = random.split(key)

    potential_fn = partial(potential_energy, model, model_args, model_kwargs,
                           enum=enum)
    if validate_grad:
        if forward_mode_differentiation:
            pe = potential_fn(params)
            z_grad = jacfwd(potential_fn)(params)
        else:
            pe, z_grad = value_and_grad(potential_fn)(params)
        z_grad_flat = ravel_pytree(z_grad)[0]
        is_valid = jnp.isfinite(pe) & jnp.all(jnp.isfinite(z_grad_flat))
    else:
        pe = potential_fn(params)
        is_valid = jnp.isfinite(pe)
        z_grad = None
    return i + 1, key, (params, pe, z_grad), is_valid
def testEmpty(self):
    """Raveling an empty pytree gives a float32 vector that round-trips."""
    empty_tree = []
    flat, unravel_fn = flatten_util.ravel_pytree(empty_tree)
    # By convention an empty input ravels to float32.
    self.assertEqual(flat.dtype, jnp.float32)
    rebuilt = unravel_fn(flat)
    self.assertAllClose(empty_tree, rebuilt, atol=0., rtol=0.)
def init_kernel(init_params, num_warmup, step_size=1.0, adapt_step_size=True, adapt_mass_matrix=True, dense_mass=False, target_accept_prob=0.8, trajectory_length=2 * math.pi, max_tree_depth=10, run_warmup=True, progbar=True, rng=PRNGKey(0)): """ Initializes the HMC sampler. :param init_params: Initial parameters to begin sampling. The type must be consistent with the input type to `potential_fn`. :param int num_warmup_steps: Number of warmup steps; samples generated during warmup are discarded. :param float step_size: Determines the size of a single step taken by the verlet integrator while computing the trajectory using Hamiltonian dynamics. If not specified, it will be set to 1. :param bool adapt_step_size: A flag to decide if we want to adapt step_size during warm-up phase using Dual Averaging scheme. :param bool adapt_mass_matrix: A flag to decide if we want to adapt mass matrix during warm-up phase using Welford scheme. :param bool dense_mass: A flag to decide if mass matrix is dense or diagonal (default when ``dense_mass=False``) :param float target_accept_prob: Target acceptance probability for step size adaptation using Dual Averaging. Increasing this value will lead to a smaller step size, hence the sampling will be slower but more robust. Default to 0.8. :param float trajectory_length: Length of a MCMC trajectory for HMC. Default value is :math:`2\\pi`. :param int max_tree_depth: Max depth of the binary tree created during the doubling scheme of NUTS sampler. Defaults to 10. :param bool run_warmup: Flag to decide whether warmup is run. If ``True``, `init_kernel` returns an initial :data:`HMCState` that can be used to generate samples using MCMC. Else, returns the arguments and callable that does the initial adaptation. :param bool progbar: Whether to enable progress bar updates. Defaults to ``True``. 
:param bool heuristic_step_size: If ``True``, a coarse grained adjustment of step size is done at the beginning of each adaptation window to achieve `target_acceptance_prob`. :param jax.random.PRNGKey rng: random key to be used as the source of randomness. """ step_size = float(step_size) nonlocal momentum_generator, wa_update, trajectory_len, max_treedepth, wa_steps wa_steps = num_warmup trajectory_len = float(trajectory_length) max_treedepth = max_tree_depth z = init_params z_flat, unravel_fn = ravel_pytree(z) momentum_generator = partial(_sample_momentum, unravel_fn) find_reasonable_ss = partial(find_reasonable_step_size, potential_fn, kinetic_fn, momentum_generator) wa_init, wa_update = warmup_adapter( num_warmup, adapt_step_size=adapt_step_size, adapt_mass_matrix=adapt_mass_matrix, dense_mass=dense_mass, target_accept_prob=target_accept_prob, find_reasonable_step_size=find_reasonable_ss) rng_hmc, rng_wa = random.split(rng) wa_state = wa_init(z, rng_wa, step_size, mass_matrix_size=np.size(z_flat)) r = momentum_generator(wa_state.mass_matrix_sqrt, rng) vv_state = vv_init(z, r) hmc_state = HMCState(0, vv_state.z, vv_state.z_grad, vv_state.potential_energy, 0, 0., 0., wa_state, rng_hmc) if run_warmup and num_warmup > 0: # JIT if progress bar updates not required if not progbar: hmc_state = jit(fori_loop, static_argnums=(2, ))( 0, num_warmup, lambda *args: sample_kernel(args[1]), hmc_state) else: with tqdm.trange(num_warmup, desc='warmup') as t: for i in t: hmc_state = sample_kernel(hmc_state) t.set_postfix_str(get_diagnostics_str(hmc_state), refresh=False) return hmc_state
def fori_collect(lower, upper, body_fun, init_val, transform=identity, progbar=True, **progbar_opts):
    """
    Run ``body_fun`` ``upper`` times, like :func:`~jax.lax.fori_loop`, and
    collect the (transformed) values it produces from iteration ``lower`` on.

    Each collected value is ``transform(val)`` raveled into a flat vector; the
    vectors are stacked along a new leading axis and unraveled back to the
    structure of ``transform(init_val)``. Refer to :func:`~numpyro.mcmc.hmc`
    for example usage. ``progbar=False`` runs the whole loop under ``jit`` and
    is faster, especially when collecting many samples.

    :param int lower: number of initial iterations whose values are skipped.
    :param int upper: total number of times to run the loop body; must be
        greater than ``lower``.
    :param body_fun: callable mapping a collection of `np.ndarray` to a
        collection with the same shape and `dtype`.
    :param init_val: initial value passed to `body_fun`; any Python collection
        of `np.ndarray` objects.
    :param transform: callable that post-processes each value from `body_fun`.
    :param progbar: whether to display progress bar updates.
    :param `**progbar_opts`: extra progress bar options. ``diagnostics_fn``,
        when given, is called on the current value and its string result is
        used as the progress bar postfix; ``progbar_desc`` labels the bar.
        Other keys are ignored.
    :return: collection with the structure of `transform(init_val)` whose
        leaves hold the collected values along their leading axis.
    """
    assert lower < upper
    init_val_flat, unravel_fn = ravel_pytree(transform(init_val))

    def ravel_fn(x):
        return ravel_pytree(transform(x))[0]

    if progbar:
        diagnostics_fn = progbar_opts.pop('diagnostics_fn', None)
        progbar_desc = progbar_opts.pop('progbar_desc', '')
        jitted_ravel = jit(ravel_fn)
        samples = []
        current = init_val
        with tqdm.trange(upper, desc=progbar_desc) as bar:
            for step in bar:
                current = body_fun(current)
                if step >= lower:
                    samples.append(jitted_ravel(current))
                if diagnostics_fn:
                    bar.set_postfix_str(diagnostics_fn(current), refresh=False)
        # XXX: jax.numpy.stack/concatenate is currently slow
        collection = onp.stack(samples)
    else:
        buffer = np.zeros((upper - lower,) + init_val_flat.shape)

        def _body_fn(step, carry):
            current, buf = carry
            current = body_fun(current)
            # Steps before `lower` write to slot 0; that slot is overwritten
            # once `step == lower`.
            slot = np.where(step >= lower, step - lower, 0)
            buf = ops.index_update(buf, slot, ravel_fn(current))
            return current, buf

        looped = jit(fori_loop, static_argnums=(2,))
        _, collection = looped(0, upper, _body_fn, (init_val, buffer))
    return vmap(unravel_fn)(collection)
def augmented_post_drift(rng, state_unravel_fn, param_unravel_fn, driftx_apply_fn,
                         prior_driftw_apply_fn, driftw0_apply_fn, driftw_apply_fn,
                         diff_apply_fn, W_driftwt, W_priorwt, W_diffusion, W0_post,
                         setting, t, aug, eps, **kwargs):  # last 3 partial
    """Drift of the flat augmented posterior state.

    aug: flat augmented state; ``state_unravel_fn(aug)`` yields
    ``[flat_w, x, kl]`` (flat weights, data state, KL term).

    Returns the flattened ``[dwtkl, dxt, dkl]``: the weight drift, the data
    drift and the KL-term increment, in the same order ``state_unravel_fn``
    unpacks the state.

    ``setting`` keys read here: 'diffw', 'stop_grad', 'priorw0_sigma',
    'diff_drift'. ``rng`` and ``driftw0_apply_fn`` are only referenced in
    commented-out code below.

    NOTE(review): ``np`` is presumably ``jax.numpy`` (values are traced) —
    confirm against the file's imports. The ``print`` calls presumably run at
    trace time if this is traced under jit — confirm.
    """
    flat_w, x, kl = state_unravel_fn(aug)  # flat_w inputs
    w = param_unravel_fn(flat_w)  # hierarchical w params
    _, dxt = driftx_apply_fn(w, (t, x), **kwargs)  # TODO: does this need to be bayesian? No we sample once only.
    _, dwtkl = driftw_apply_fn(W_driftwt, (t, flat_w), **kwargs)
    # pw_drift = 0. # assume prior has zero drift
    _, pw_drift = prior_driftw_apply_fn(W_priorwt, (t, flat_w), **kwargs)  # account for prior
    _, diffwt = diff_apply_fn(W_diffusion, (t, flat_w), **kwargs)  # with time dependence

    # stop gradient only to compute the KL
    _dwtkl = dwtkl  # already "includes the w0kl", no need to add again.
    # u is all zeros unless setting['diffw'], in which case it is the
    # posterior-minus-prior drift scaled by the inverse diffusion.
    u = (_dwtkl - pw_drift) * 0
    if setting['diffw']:
        u = (_dwtkl - pw_drift) / diffwt
    dkl = np.sum((u**2) / 2)  # dt = 1
    # Stochastic term sum(u * eps); replaced below when setting['stop_grad'].
    u2 = u * eps
    dkl_stl = np.sum(u2)
    # _w0kl, _w0_prior = 0, 0 # for gaussian mixture on W0
    if setting['stop_grad']:
        print("applying sticking the landing")
        if setting['priorw0_sigma'] >= 0:
            print("stopping gradient on w0 posterior")
            # stop grad on posterior over w0
            # NOTE(review): W0_post is rebound here but not read again in this
            # function (its only consumer below is commented out), so this
            # stop_gradient currently has no effect.
            W0_post = tree_map(lambda w: lax.stop_gradient(w), W0_post)
            # _w0kl = tree_map(lambda w: lax.stop_gradient(w), _w0kl) # TODO: alternative?
            # w0out, _w0, _w0kl = driftw0_apply_fn(W0_post, (0, x), rng=rng, logsigma2=np.log(setting['priorw0_sigma'])) # approx posterior over weights
        # Recompute the weight drift with stop-gradient parameters; the result
        # (_dwtkl) feeds only the second STL term below.
        if setting['diff_drift']:
            print("stopping gradient on difference parameterization")
            # Only the last slot of W_driftwt is stopped; slot 0 is kept as-is.
            W_driftwtkl = tree_map(lambda w: lax.stop_gradient(w), W_driftwt[-1])
            W_driftwtkl = (W_driftwt[0], W_driftwtkl)  # first slot is params for t, which are ()
            _, _dwtkl = driftw_apply_fn(W_driftwtkl, (t, flat_w), **kwargs)  # dwtkl
        else:
            W_driftwtkl = tree_map(lambda w: lax.stop_gradient(w), W_driftwt)
            _, _dwtkl = driftw_apply_fn(W_driftwtkl, (t, flat_w), **kwargs)  # dwtkl

    # for original (but wrong) formulation of STL
    # u = (_dwtkl - pw_drift) * 0
    # if setting['diffw']:
    #     u = (_dwtkl - pw_drift) / diffwt
    # dkl = np.sum((u ** 2) / 2)
    ## dkl = dkl + _w0kl # account for bayesian w0 parameters

    # New STL: second term stop gradient (not first term - moved above)
    if setting['stop_grad']:
        print("stopping gradient on second STL term")
        u2 = (_dwtkl - pw_drift) * 0
        if setting['diffw']:
            u2 = (_dwtkl - pw_drift) / diffwt
        u2 = u2 * eps  # u dBt
        dkl_stl = np.sum(u2)
    # Without stop_grad, dkl_stl is still the un-stopped term computed above.
    dkl = dkl + dkl_stl
    # outputs = [dxt, dwtkl, dkl]
    # Order matches state_unravel_fn's [flat_w, x, kl] layout.
    outputs = [dwtkl, dxt, dkl]
    return ravel_pytree(outputs)[0]
def _initialize_mass_matrix(z, inverse_mass_matrix, dense_mass): if isinstance(dense_mass, list): if inverse_mass_matrix is None: inverse_mass_matrix = {} # if user specifies an ndarray mass matrix, then we convert it to a dict elif not isinstance(inverse_mass_matrix, dict): inverse_mass_matrix = {tuple(sorted(z)): inverse_mass_matrix} mass_matrix_sqrt = {} mass_matrix_sqrt_inv = {} for site_names in dense_mass: inverse_mm = inverse_mass_matrix.get(site_names) z_block = tuple(z[k] for k in site_names) inverse_mm, mm_sqrt, mm_sqrt_inv = _initialize_mass_matrix( z_block, inverse_mm, True ) inverse_mass_matrix[site_names] = inverse_mm mass_matrix_sqrt[site_names] = mm_sqrt mass_matrix_sqrt_inv[site_names] = mm_sqrt_inv # NB: this branch only happens when users want to use block diagonal # inverse_mass_matrix, for example, {("a",): jnp.ones(3), ("b",): jnp.ones(3)}. for site_names, inverse_mm in inverse_mass_matrix.items(): if site_names in dense_mass: continue z_block = tuple(z[k] for k in site_names) inverse_mm, mm_sqrt, mm_sqrt_inv = _initialize_mass_matrix( z_block, inverse_mm, False ) inverse_mass_matrix[site_names] = inverse_mm mass_matrix_sqrt[site_names] = mm_sqrt mass_matrix_sqrt_inv[site_names] = mm_sqrt_inv remaining_sites = tuple(sorted(set(z) - set().union(*inverse_mass_matrix))) if len(remaining_sites) > 0: z_block = tuple(z[k] for k in remaining_sites) inverse_mm, mm_sqrt, mm_sqrt_inv = _initialize_mass_matrix( z_block, None, False ) inverse_mass_matrix[remaining_sites] = inverse_mm mass_matrix_sqrt[remaining_sites] = mm_sqrt mass_matrix_sqrt_inv[remaining_sites] = mm_sqrt_inv expected_site_names = sorted(z) actual_site_names = sorted( [k for site_names in inverse_mass_matrix for k in site_names] ) assert actual_site_names == expected_site_names, ( "There seems to be a conflict of sites names specified in the initial" " `inverse_mass_matrix` and in `dense_mass` argument." 
) return inverse_mass_matrix, mass_matrix_sqrt, mass_matrix_sqrt_inv mass_matrix_size = jnp.size(ravel_pytree(z)[0]) if inverse_mass_matrix is None: if dense_mass: inverse_mass_matrix = jnp.identity(mass_matrix_size) else: inverse_mass_matrix = jnp.ones(mass_matrix_size) mass_matrix_sqrt = mass_matrix_sqrt_inv = inverse_mass_matrix else: if dense_mass: if jnp.ndim(inverse_mass_matrix) == 1: inverse_mass_matrix = jnp.diag(inverse_mass_matrix) mass_matrix_sqrt_inv = jnp.swapaxes( jnp.linalg.cholesky(inverse_mass_matrix[..., ::-1, ::-1])[ ..., ::-1, ::-1 ], -2, -1, ) identity = jnp.identity(inverse_mass_matrix.shape[-1]) mass_matrix_sqrt = solve_triangular( mass_matrix_sqrt_inv, identity, lower=True ) else: if jnp.ndim(inverse_mass_matrix) == 2: inverse_mass_matrix = jnp.diag(inverse_mass_matrix) mass_matrix_sqrt_inv = jnp.sqrt(inverse_mass_matrix) mass_matrix_sqrt = jnp.reciprocal(mass_matrix_sqrt_inv) return inverse_mass_matrix, mass_matrix_sqrt, mass_matrix_sqrt_inv
def fista_step(data, loss_and_prox_op, model_param, options):
    """Fista optimization step for solving regularized problem.

    Args:
      data: A tuple of inputs and labels passed to the loss function.
      loss_and_prox_op: Tuple of (loss_f, prox_g).
        loss_f is the loss function that takes in model_param, inputs, and
        labels. prox_g is the proximity operator for g; it is called on flat
        parameter vectors as ``prox_g(flat_params, step_size)``.
      model_param: Current model parameters to be passed to loss_f (any
        pytree).
      options: A dictionary of optimizer specific hyper-parameters
        ('step_size', 'acceleration', 't', 'y', 'verbose', 'reuse_last_step',
        plus whatever `backtracking` reads).

    Returns:
      A tuple ``(new_model_param, new_options)``. ``new_model_param`` has the
      same structure as ``model_param``. ``new_options`` is a copy of
      ``options``. It carries the updated momentum state 'y' and 't' when
      acceleration is on, and the accepted 'step_size' when
      ``reuse_last_step`` is set.
    """
    options = dict(options)
    step_size = options.get('step_size', 1.0)
    acceleration = options.get('acceleration', True)
    t = options.get('t', 1.0)
    verbose = options.get('verbose', False)
    reuse_last_step = options.get('reuse_last_step', False)

    loss_f, prox_g = loss_and_prox_op
    inputs, labels = data[0], data[1]

    def fun_f(param):
        return loss_f(param, inputs, labels)

    value_and_grad_f = jax.value_and_grad(fun_f)
    x, unravel_fn = ravel_pytree(model_param)
    # Extrapolated point from the previous step (flat); defaults to x.
    y = options.get('y', x)
    value_f, grad_f = value_and_grad_f(unravel_fn(y))
    grad_f, _ = ravel_pytree(grad_f)

    def next_candidate(step_size):
        return prox_g(y - grad_f * step_size, step_size)

    def stop_cond(step_size, next_iter):
        diff = next_iter - y
        sqdist = jnp.sum(diff**2)

        # We do not compute the non-smooth term (g in the paper)
        # as it cancels out from value_F and value_Q.
        # `next_iter` is flat: unravel it before calling the loss, which
        # expects the original parameter structure.
        value_bigf = fun_f(unravel_fn(next_iter))
        value_bigq = value_f + jnp.sum(diff * grad_f) + 0.5 / step_size * sqdist
        return value_bigf <= value_bigq

    x_old = x
    # NOTE(review): `backtracking` (defined elsewhere) presumably shrinks
    # step_size until stop_cond holds and returns (step_size, candidate).
    step_size, x = backtracking(next_candidate, stop_cond, step_size, options)

    # Acceleration.
    if acceleration:
        t_next = (1 + jnp.sqrt(1 + 4 * t**2)) / 2.
        y = x + (t - 1) / t_next * (x - x_old)
        t = t_next
        options['y'] = y
        options['t'] = t

    if reuse_last_step:
        options['step_size'] = step_size
    if verbose:
        logging.info('Step size: %f', step_size)

    return unravel_fn(x), options
def gradient_descent_line_search_step(data, loss_f, model_param, options):
    """Gradient Descent optimization with line search step.

    Args:
      data: A tuple of inputs and labels passed to the loss function.
      loss_f: The loss function that takes in model_param, inputs, and labels.
      model_param: Current model parameters to be passed to loss_f (any
        pytree).
      options: A dictionary of optimizer specific hyper-parameters
        ('beta', 'beta_prime', 'step_size', 'verbose', 'reuse_last_step',
        'bound_step', plus whatever `backtracking` reads). A missing
        'bound_step' is treated as False.

    Returns:
      A tuple ``(new_model_param, new_options)``. ``new_model_param`` has the
      same structure as ``model_param``. ``new_options`` is a copy of
      ``options``; it also holds the accepted 'step_size' when
      ``reuse_last_step`` is set.
    """
    options = dict(options)
    beta = options.get('beta', 0.9)
    beta_prime = options.get('beta_prime', 1e-4)
    step_size = options.get('step_size', 10000.0)
    verbose = options.get('verbose', False)
    reuse_last_step = options.get('reuse_last_step', False)

    inputs, labels = data[0], data[1]

    def loss_with_data_f(param):
        return loss_f(param, inputs, labels)

    value_and_grad_f = jax.value_and_grad(loss_with_data_f)
    value, grad = value_and_grad_f(model_param)

    # Maximum learning rate allowed from Theorem 5 in Gunasekar et al. 2017
    # (.get: every other option has a default; indexing raised KeyError here).
    if options.get('bound_step', False):
        # Bound by dual of L2
        b_const = jnp.max(jnp.linalg.norm(inputs, ord=2, axis=0))
        step_size = min(step_size, 1 / (b_const * b_const * value))

    grad, _ = ravel_pytree(grad)
    x, unravel_fn = ravel_pytree(model_param)

    # If we normalize step_size will be harder to tune.
    direction = -grad
    # Directional derivative at x; constant across line-search trials.
    gd = jnp.sum(grad * direction)

    # TODO(fartash): consider using the condition in FISTA
    def next_candidate(step_size):
        next_iter = x + step_size * direction
        next_value, next_grad = value_and_grad_f(unravel_fn(next_iter))
        next_grad, _ = ravel_pytree(next_grad)
        return next_iter, next_value, next_grad

    def stop_cond(step_size, res):
        _, next_value, next_grad = res
        # Strong Wolfe condition.
        cond1 = next_value <= value + beta_prime * step_size * gd
        # NOTE(review): gd <= 0, so this "curvature" check is always true and
        # only the sufficient-decrease test (cond1) is effective.
        cond2 = jnp.sum(jnp.abs(next_grad * direction)) >= beta * gd
        # logical_and (not Python `and`) keeps this valid on traced arrays.
        return jnp.logical_and(cond1, cond2)

    step_size, res = backtracking(next_candidate, stop_cond, step_size,
                                  options=options)
    next_param = res[0]

    if reuse_last_step:
        options['step_size'] = step_size
    if verbose:
        logging.info('Step size: %f', step_size)

    return unravel_fn(next_param), options
def update(i, g, w):
    """Proximal gradient step: move ``w`` along ``-g``, then soft-threshold.

    ``i`` (the step index) is accepted for API compatibility and unused.
    The result is unraveled with ``g``'s structure.
    """
    grad_flat, unravel_like_g = ravel_pytree(g)
    threshold = step_size * lambd
    stepped = ravel_pytree_jit(w) - step_size * grad_flat
    return unravel_like_g(soft_thresholding(stepped, threshold))
def next_candidate(step_size):
    """Take a trial step of length ``step_size`` along ``direction``.

    Returns ``(flat_candidate, loss_value, flat_gradient)`` at the candidate.
    """
    candidate = x + step_size * direction
    cand_value, cand_grad_tree = value_and_grad_f(unravel_fn(candidate))
    cand_grad = ravel_pytree(cand_grad_tree)[0]
    return candidate, cand_value, cand_grad
g_flat, unflatten = ravel_pytree(g) w_flat = ravel_pytree_jit(w) updated_params = soft_thresholding(w_flat - step_size * g_flat, step_size * lambd) return unflatten(updated_params) def get_params(w): return w def set_step_size(lr): step_size = lr return init, update, get_params, soft_thresholding, set_step_size ravel_pytree_jit = jit(lambda tree: ravel_pytree(tree)[0]) @jit def line_search(w, g, batch, beta): lr_i = 1 g_flat, unflatten_g = ravel_pytree(g) w_flat = ravel_pytree_jit(w) z_flat = soft_thresholding(w_flat - lr_i * g_flat, lr_i * lambd) z = unflatten_g(z_flat) for i in range(20): is_converged = loss( z, batch) > loss(w, batch) + g_flat @ (z_flat - w_flat) + np.sum( (z_flat - w_flat)**2) / (2 * lr_i) lr_i = jnp.where(is_converged, lr_i, beta * lr_i) return lr_i
def ravel_first_arg_(unravel, y_flat, *args):
    """Generator adapter: unravel the first argument in, ravel the answer out.

    First yields ``((unravel(y_flat),) + args, {})``; the value sent back in
    is flattened with ``ravel_pytree`` and yielded as a flat vector.
    """
    ans = yield (unravel(y_flat),) + args, {}
    yield ravel_pytree(ans)[0]
def _odeint_wrapper(func, rtol, atol, mxstep, y0, ts, *args):
    """Solve the ODE on a flattened state and restore the pytree structure.

    ``y0`` is raveled, ``func`` is wrapped so that it works on flat states, and
    every row of the solver output is unraveled back to ``y0``'s structure.
    """
    y0_flat, unravel = ravel_pytree(y0)
    flat_func = ravel_first_arg(func, unravel)
    solution = _odeint(flat_func, rtol, atol, mxstep, y0_flat, ts, *args)
    unravel_rows = jax.vmap(unravel)
    return unravel_rows(solution)
def aug_init(W0, _inputs):  # partial over last
    """Build the initial flat augmented state ``[W0, inputs, 0.]``.

    Returns ``(flat_aug, flat_w_dim)``, where ``flat_w_dim`` is the number of
    scalars contributed by ``W0`` (the leading slice of ``flat_aug``).
    """
    flat_aug = ravel_pytree([W0, _inputs, 0.])[0]
    flat_w = ravel_pytree([W0])[0]
    return flat_aug, len(flat_w)
def _weight_fn(params): flat_params, _ = ravel_pytree(params) return 0.5 * jnp.sum(jnp.square(flat_params))
def apply_fun(params, inputs, rng, **kwargs):
    """(aka predict_apply_fn) Integrate the augmented posterior and prior SDEs.

    Args:
      params: 5-tuple ``(W0, W_driftwt, W_diffusion, W_priorwt, W0_post)``.
        When ``W0_post`` is not None, ``W0`` is replaced by a draw from
        ``driftw0_apply_fn`` and its KL term ``W0kl`` is used to shape the
        augmented state.
      inputs: x_t0 ... x_tn
      rng: random key forwarded to the W0 posterior and to the SDE solvers.
      **kwargs: forwarded to the drift/diffusion apply functions. If
        ``kwargs['entire']`` is truthy, full trajectories are returned.

    Returns:
      With ``entire``: ``(xs, ws, kls, prior_xs, prior_ws, prior_kls)``, each a
      per-time-step tuple from ``sdeint`` on a 90-point grid.
      Otherwise: ``(x, w, kl, prior_xs, prior_ws, prior_kls)`` unpacked from
      ``_sdeint`` on a 20-point grid (presumably the terminal state — the
      original comment says "was sdeint[-1]").
    """
    W0, W_driftwt, W_diffusion, W_priorwt, W0_post = params
    # _, W0 = driftx_init_fn(rng, (-1, x_dim)) # W0 is randomized each call
    # Bayesian about W0 (mean field)
    W0kl = 0.
    if W0_post is not None:
        tw0 = 0  # NOTE: this is just a dummy variable to be compatible with `shape_dependent`
        _, W0, W0kl = driftw0_apply_fn(W0_post, (tw0, inputs),
                                       rng=rng,
                                       logsigma2=np.log(setting['priorw0_sigma']))
    flat_W0, param_unravel_fn = ravel_pytree(W0)
    _, state_unravel_fn = ravel_pytree([flat_W0, inputs, W0kl])

    # The initial augmented state depends only on (W0, inputs): build it once
    # instead of re-raveling it for every solver call.
    aug0, flat_w_dim = aug_init(W0, inputs)

    aug_post_drift = lambda t_, aug_, eps_: augmented_post_drift(
        rng, state_unravel_fn, param_unravel_fn, driftx_apply_fn,
        prior_driftw_apply_fn, driftw0_apply_fn, driftw_apply_fn,
        diff_apply_fn, W_driftwt, W_priorwt, W_diffusion, W0_post, setting,
        t_, aug_, eps_, **kwargs)
    aug_prior_drift = lambda t_, aug_, eps_: augmented_prior_drift(
        state_unravel_fn, param_unravel_fn, driftx_apply_fn,
        prior_driftw_apply_fn, diff_apply_fn, W_priorwt, W_diffusion,
        setting, t_, aug_, eps_, **kwargs)
    aug_diffusion = lambda t_, aug_: augmented_diffusion(
        state_unravel_fn, param_unravel_fn, diff_apply_fn, W_diffusion, t_,
        aug_, **kwargs)

    if kwargs.get('entire', False):
        print("returning entire trajectory...")
        ts = np.linspace(0, 1, 90)  # TODO: to prevent oom problems
        outputs = sdeint(aug_post_drift, aug_diffusion, aug0, flat_w_dim, ts, rng)
        ws, xs, kls = zip(*[state_unravel_fn(out) for out in outputs])  # list(x_dim), list(w_dim), list(1)
        prior_outputs = sdeint(aug_prior_drift, aug_diffusion, aug0, flat_w_dim, ts, rng)  # TODO: 100 steps
        prior_ws, prior_xs, prior_kls = zip(*[
            state_unravel_fn(out) for out in prior_outputs
        ])  # list(x_dim), list(w_dim), list(1)
        return xs, ws, kls, prior_xs, prior_ws, prior_kls

    ts = np.linspace(0, 1, 20)
    solution = _sdeint(aug_post_drift, aug_diffusion, aug0, flat_w_dim, ts, rng)  # was sdeint[-1]
    w, x, kl = state_unravel_fn(solution)
    # TODO: ramp up 100 steps for better performance
    prior_outputs = _sdeint(aug_prior_drift, aug_diffusion, aug0, flat_w_dim, ts, rng)
    prior_ws, prior_xs, prior_kls = state_unravel_fn(prior_outputs)
    return x, w, kl, prior_xs, prior_ws, prior_kls
def scipy_minimize_with_jax(fun, x0,
                            method=None,
                            args=(),
                            bounds=None,
                            constraints=(),
                            tol=None,
                            callback=None,
                            options=None):
    """
    A simple wrapper for scipy.optimize.minimize using JAX.

    Parameters
    ----------
    fun: function
      The objective function to be minimized, written in JAX code
      so that it is automatically differentiable. It is of type,
      ```fun: x, *args -> float``` where `x` is a PyTree and args
      is a tuple of the fixed parameters needed to completely specify the function.

    x0: PyTree
      Initial guess, represented as a JAX PyTree (e.g. a ``jnp.ndarray`` or a
      dict/list of arrays).

    args: tuple, optional.
      Extra arguments passed to the objective function
      and its derivative. Must consist of valid JAX types; e.g. the leaves
      of the PyTree must be floats.

    method : str or callable, optional
      Type of solver. Should be one of

      - 'Nelder-Mead' :ref:`(see here) <optimize.minimize-neldermead>`
      - 'Powell' :ref:`(see here) <optimize.minimize-powell>`
      - 'CG' :ref:`(see here) <optimize.minimize-cg>`
      - 'BFGS' :ref:`(see here) <optimize.minimize-bfgs>`
      - 'Newton-CG' :ref:`(see here) <optimize.minimize-newtoncg>`
      - 'L-BFGS-B' :ref:`(see here) <optimize.minimize-lbfgsb>`
      - 'TNC' :ref:`(see here) <optimize.minimize-tnc>`
      - 'COBYLA' :ref:`(see here) <optimize.minimize-cobyla>`
      - 'SLSQP' :ref:`(see here) <optimize.minimize-slsqp>`
      - 'trust-constr' :ref:`(see here) <optimize.minimize-trustconstr>`
      - 'dogleg' :ref:`(see here) <optimize.minimize-dogleg>`
      - 'trust-ncg' :ref:`(see here) <optimize.minimize-trustncg>`
      - 'trust-exact' :ref:`(see here) <optimize.minimize-trustexact>`
      - 'trust-krylov' :ref:`(see here) <optimize.minimize-trustkrylov>`
      - custom - a callable object (added in version 0.14.0),
        see below for description.

      If not given, chosen to be one of ``BFGS``, ``L-BFGS-B``, ``SLSQP``,
      depending on if the problem has constraints or bounds.

    bounds : sequence or `Bounds`, optional
      Bounds on variables for L-BFGS-B, TNC, SLSQP, Powell, and
      trust-constr methods. There are two ways to specify the bounds:

      1. Instance of `Bounds` class.
      2. Sequence of ``(min, max)`` pairs for each element in `x`. None
         is used to specify no bound.

      Note that in order to use `bounds` you will need to manually flatten
      them in the same order as your inputs `x0`.

    constraints : {Constraint, dict} or List of {Constraint, dict}, optional
      Constraints definition (only for COBYLA, SLSQP and trust-constr).

      Constraints for 'trust-constr' are defined as a single object or a
      list of objects specifying constraints to the optimization problem.
      Available constraints are:

      - `LinearConstraint`
      - `NonlinearConstraint`

      Constraints for COBYLA, SLSQP are defined as a list of dictionaries.
      Each dictionary with fields:

      type : str
        Constraint type: 'eq' for equality, 'ineq' for inequality.
      fun : callable
        The function defining the constraint.
      jac : callable, optional
        The Jacobian of `fun` (only for SLSQP).
      args : sequence, optional
        Extra arguments to be passed to the function and Jacobian.

      Equality constraint means that the constraint function result is to
      be zero whereas inequality means that it is to be non-negative.
      Note that COBYLA only supports inequality constraints.

      Note that in order to use `constraints` you will need to manually flatten
      them in the same order as your inputs `x0`.

    tol : float, optional
      Tolerance for termination. For detailed control, use solver-specific
      options.

    options : dict, optional
        A dictionary of solver options. All methods accept the following
        generic options:

        maxiter : int
            Maximum number of iterations to perform. Depending on the
            method each iteration may use several function evaluations.
        disp : bool
            Set to True to print convergence messages.

        For method-specific options, see :func:`show_options()`.

    callback : callable, optional
        Called after each iteration. For 'trust-constr' it is a callable with
        the signature:

            ``callback(xk, OptimizeResult state) -> bool``

        where ``xk`` is the current parameter vector represented as a PyTree,
        and ``state`` is an `OptimizeResult` object, with the same fields
        as the ones from the return. If callback returns True the algorithm
        execution is terminated.

        For all the other methods, the signature is:

            ```callback(xk)```

        where `xk` is the current parameter vector, represented as a PyTree.

    Returns
    -------
    res : The optimization result represented as a ``OptimizeResult`` object.
      Important attributes are:
        ``x``: the solution, represented as a JAX PyTree with the structure of
        ``x0``;
        ``success``: a Boolean flag indicating if the optimizer exited
        successfully;
        ``message``: describes the cause of the termination.

      See `scipy.optimize.OptimizeResult` for a description of other attributes.

    Raises
    ------
    errors.PackageMissingError
      If scipy is not installed.
    """
    if soptimize is None:
        # Report the function's name, not its object repr (which would render
        # as "<function scipy_minimize_with_jax at 0x...>").
        raise errors.PackageMissingError(f'"scipy" must be installed when user want to use '
                                         f'function: {scipy_minimize_with_jax.__name__}')

    # Use tree flatten and unflatten to convert params x0 from PyTrees to flat arrays
    x0_flat, unravel = ravel_pytree(x0)

    # Wrap the objective function to consume flat _original_
    # numpy arrays and produce scalar outputs.
    def fun_wrapper(x_flat, *args):
        x = unravel(x_flat)
        r = fun(x, *args)
        r = r.value if isinstance(r, bm.JaxArray) else r
        return float(r)

    # Wrap the gradient in a similar manner
    jac = jit(grad(fun))

    def jac_wrapper(x_flat, *args):
        x = unravel(x_flat)
        g_flat, _ = ravel_pytree(jac(x, *args))
        return np.array(g_flat)

    # Wrap the callback to consume a pytree
    def callback_wrapper(x_flat, *args):
        if callback is not None:
            x = unravel(x_flat)
            return callback(x, *args)

    # Minimize with scipy
    results = soptimize.minimize(fun_wrapper,
                                 x0_flat,
                                 args=args,
                                 method=method,
                                 jac=jac_wrapper,
                                 callback=callback_wrapper,
                                 bounds=bounds,
                                 constraints=constraints,
                                 tol=tol,
                                 options=options)

    # pack the output back into a PyTree
    results["x"] = unravel(results["x"])
    return results
def build_tree(
    verlet_update,
    kinetic_fn,
    verlet_state,
    inverse_mass_matrix,
    step_size,
    rng_key,
    max_delta_energy=1000.0,
    max_tree_depth=10,
):
    """
    Grow the NUTS binary tree from ``verlet_state`` by repeated doubling.

    Doubling continues while the tree depth is below the current depth limit
    and the tree is neither turning nor diverging. Each doubling picks its
    direction with a fair coin flip.

    **References:**

    1. *The No-U-Turn Sampler: Adaptively Setting Path Lengths in Hamiltonian Monte Carlo*,
       Matthew D. Hoffman, Andrew Gelman
    2. *A Conceptual Introduction to Hamiltonian Monte Carlo*,
       Michael Betancourt

    :param verlet_update: callable mapping an integrator state to the next one.
    :param kinetic_fn: callable computing the kinetic energy.
    :param verlet_state: starting integrator state ``(z, r, potential_energy, z_grad)``.
    :param inverse_mass_matrix: inverse of the mass matrix.
    :param float step_size: step size used for this trajectory.
    :param jax.random.PRNGKey rng_key: source of randomness.
    :param float max_delta_energy: energy-difference threshold (relative to the
        starting state) above which a new state counts as diverging.
    :param int max_tree_depth: maximum tree depth (default 10). A tuple
        ``(d1, d2)`` is also accepted: ``d1`` limits the depth for this MCMC
        step, ``d2`` is the global limit used to size the checkpoint buffers.
    :return: the final tree.
    :rtype: :data:`TreeInfo`
    """
    if isinstance(max_tree_depth, tuple):
        depth_limit, max_tree_depth = max_tree_depth
    else:
        depth_limit = max_tree_depth

    z, r, potential_energy, z_grad = verlet_state
    energy_current = potential_energy + kinetic_fn(inverse_mass_matrix, r)

    # Checkpoint buffers are sized by the global depth limit.
    num_latent = jnp.size(ravel_pytree(r)[0])
    ckpt_shape = (max_tree_depth, num_latent)
    r_ckpts = jnp.zeros(ckpt_shape)
    r_sum_ckpts = jnp.zeros(ckpt_shape)

    root = TreeInfo(
        z, r, z_grad,
        z, r, z_grad,
        z, potential_energy, z_grad,
        energy_current,
        depth=0,
        weight=jnp.zeros(()),
        r_sum=r,
        turning=jnp.array(False),
        diverging=jnp.array(False),
        sum_accept_probs=jnp.zeros(()),
        num_proposals=jnp.array(0, dtype=jnp.result_type(int)),
    )

    def keep_doubling(carry):
        tree, _ = carry
        return (tree.depth < depth_limit) & ~tree.turning & ~tree.diverging

    def double_once(carry):
        tree, key = carry
        key, direction_key, doubling_key = random.split(key, 3)
        going_right = random.bernoulli(direction_key)
        grown = _double_tree(
            tree,
            verlet_update,
            kinetic_fn,
            inverse_mass_matrix,
            step_size,
            going_right,
            doubling_key,
            energy_current,
            max_delta_energy,
            r_ckpts,
            r_sum_ckpts,
        )
        return grown, key

    final_tree, _ = while_loop(keep_doubling, double_once, (root, rng_key))
    return final_tree
def jac_wrapper(x_flat, *args):
    """Gradient of the objective at a flat point, returned as a flat numpy array."""
    grad_tree = jac(unravel(x_flat), *args)
    grad_flat = ravel_pytree(grad_tree)[0]
    return np.array(grad_flat)
def fori_collect(
    lower,
    upper,
    body_fun,
    init_val,
    transform=identity,
    progbar=True,
    return_last_val=False,
    collection_size=None,
    thinning=1,
    **progbar_opts,
):
    """
    This looping construct works like :func:`~jax.lax.fori_loop` but with the additional
    effect of collecting values from the loop body. In addition, this allows for
    post-processing of these samples via `transform`, and progress bar updates.
    Note that `progbar=False` will be faster, especially when collecting a
    lot of samples. Refer to example usage in :func:`~numpyro.infer.mcmc.hmc`.

    :param int lower: the index to start the collective work. In other words,
        we will skip collecting the first `lower` values.
    :param int upper: number of times to run the loop body.
    :param body_fun: a callable that takes a collection of
        `np.ndarray` and returns a collection with the same shape and
        `dtype`.
    :param init_val: initial value to pass as argument to `body_fun`. Can
        be any Python collection type containing `np.ndarray` objects.
    :param transform: a callable to post-process the values returned by `body_fn`.
    :param progbar: whether to post progress bar updates.
    :param bool return_last_val: If `True`, the last value is also returned.
        This has the same type as `init_val`.
    :param int collection_size: Size of the returned collection. If not
        specified, the size will be ``(upper - lower) // thinning``. If the
        size is larger than ``(upper - lower) // thinning``, only the top
        ``(upper - lower) // thinning`` entries will be non-zero.
    :param thinning: Positive integer that controls the thinning ratio for retained
        values. Defaults to 1, i.e. no thinning.
    :param `**progbar_opts`: optional additional progress bar arguments. A
        `diagnostics_fn` can be supplied which when passed the current value
        from `body_fun` returns a string that is used to update the progress
        bar postfix. Also a `progbar_desc` keyword argument can be supplied
        which is used to label the progress bar. A `num_chains` entry is also
        read (default 1).
    :return: collection with the same type as `init_val` with values
        collected along the leading axis of `np.ndarray` objects.
    """
    assert lower <= upper
    assert thinning >= 1
    collection_size = (
        (upper - lower) // thinning if collection_size is None else collection_size
    )
    assert collection_size >= (upper - lower) // thinning
    init_val_flat, unravel_fn = ravel_pytree(transform(init_val))
    # Shift the start so that the final iteration (upper - 1) always lands in the collection.
    start_idx = lower + (upper - lower) % thinning
    num_chains = progbar_opts.pop("num_chains", 1)
    # host_callback does not work yet with multi-GPU platforms
    # See: https://github.com/google/jax/issues/6447
    if num_chains > 1 and jax.default_backend() == "gpu":
        warnings.warn(
            "We will disable progress bar because it does not work yet on multi-GPUs platforms."
        )
        progbar = False

    @cached_by(fori_collect, body_fun, transform)
    def _body_fn(i, vals):
        val, collection, start_idx, thinning = vals
        val = body_fun(val)
        # Negative before `start_idx`, so nothing is stored; afterwards several consecutive
        # iterations map to the same slot and the last one in each group wins.
        idx = (i - start_idx) // thinning
        collection = cond(
            idx >= 0,
            collection,
            lambda x: x.at[idx].set(ravel_pytree(transform(val))[0]),
            collection,
            identity,
        )
        return val, collection, start_idx, thinning

    collection = jnp.zeros((collection_size,) + init_val_flat.shape)
    if not progbar:
        last_val, collection, _, _ = fori_loop(
            0, upper, _body_fn, (init_val, collection, start_idx, thinning)
        )
    elif num_chains > 1:
        progress_bar_fori_loop = progress_bar_factory(upper, num_chains)
        _body_fn_pbar = progress_bar_fori_loop(_body_fn)
        last_val, collection, _, _ = fori_loop(
            0, upper, _body_fn_pbar, (init_val, collection, start_idx, thinning)
        )
    else:
        diagnostics_fn = progbar_opts.pop("diagnostics_fn", None)
        progbar_desc = progbar_opts.pop("progbar_desc", lambda x: "")
        # The jitted step is loop-invariant: build it once instead of once per iteration.
        jitted_body_fn = jit(_body_fn)

        vals = (init_val, collection, device_put(start_idx), device_put(thinning))
        if upper == 0:
            # special case, only compiling
            jitted_body_fn(0, vals)
        else:
            with tqdm.trange(upper) as t:
                for i in t:
                    vals = jitted_body_fn(i, vals)
                    t.set_description(progbar_desc(i), refresh=False)
                    if diagnostics_fn:
                        t.set_postfix_str(diagnostics_fn(vals[0]), refresh=False)

        last_val, collection, _, _ = vals

    unravel_collection = vmap(unravel_fn)(collection)
    return (unravel_collection, last_val) if return_last_val else unravel_collection
def __call__(self, params):
    """Flatten the `params` pytree into one vector and hand it to `self.initialize`."""
    flat_params = flatten_util.ravel_pytree(params)[0]
    return self.initialize(flat_params)
def sample_kernel(sa_state, model_args=(), model_kwargs=None):
    """Run one transition of the sample adaptive (SA) kernel.

    Uses ``potential_fn``, ``potential_fn_gen``, ``max_delta_energy`` and
    ``wa_steps`` from the enclosing scope.

    :param sa_state: current ``SAState``.
    :param tuple model_args: positional arguments forwarded to ``potential_fn_gen``.
    :param dict model_kwargs: keyword arguments forwarded to ``potential_fn_gen``;
        ``None`` is treated as an empty dict.
    :return: the next ``SAState``.
    """
    # `**None` raises TypeError, so normalise the default before it is splatted below.
    model_kwargs = {} if model_kwargs is None else model_kwargs
    pe_fn = potential_fn
    if potential_fn_gen:
        pe_fn = potential_fn_gen(*model_args, **model_kwargs)
    zs, pes, loc, scale = sa_state.adapt_state
    # we recompute loc/scale after each iteration to avoid precision loss
    # XXX: consider to expose a setting to do this job periodically
    # to save some computations
    loc = jnp.mean(zs, 0)
    if scale.ndim == 2:
        cov = jnp.cov(zs, rowvar=False, bias=True)
        if cov.shape == ():  # JAX returns scalar for 1D input
            cov = cov.reshape((1, 1))
        cholesky = jnp.linalg.cholesky(cov)
        # keep the previous scale if the Cholesky factorisation produced NaNs
        scale = jnp.where(jnp.any(jnp.isnan(cholesky)), scale, cholesky)
    else:
        scale = jnp.std(zs, 0)

    rng_key, rng_key_z, rng_key_reject, rng_key_accept = random.split(
        sa_state.rng_key, 4
    )
    _, unravel_fn = ravel_pytree(sa_state.z)

    z = loc + _sample_proposal(scale, rng_key_z)
    pe = pe_fn(unravel_fn(z))
    pe = jnp.where(jnp.isnan(pe), jnp.inf, pe)
    diverging = (pe - sa_state.potential_energy) > max_delta_energy

    # NB: all terms having the pattern *s will have shape N x ...
    # and all terms having the pattern *s_ will have shape (N + 1) x ...
    locs, scales = _get_proposal_loc_and_scale(zs, loc, scale, z)
    zs_ = jnp.concatenate([zs, z[None, :]])
    pes_ = jnp.concatenate([pes, pe[None]])
    locs_ = jnp.concatenate([locs, loc[None, :]])
    scales_ = jnp.concatenate([scales, scale[None, ...]])
    if scale.ndim == 2:  # dense_mass
        log_weights_ = (
            dist.MultivariateNormal(locs_, scale_tril=scales_).log_prob(zs_) + pes_
        )
    else:
        log_weights_ = dist.Normal(locs_, scales_).log_prob(zs_).sum(-1) + pes_
    # mask invalid values (nan, +inf) by -inf
    log_weights_ = jnp.where(jnp.isfinite(log_weights_), log_weights_, -jnp.inf)
    # get rejecting index
    j = random.categorical(rng_key_reject, log_weights_)
    zs = _numpy_delete(zs_, j)
    pes = _numpy_delete(pes_, j)
    loc = locs_[j]
    scale = scales_[j]
    adapt_state = SAAdaptState(zs, pes, loc, scale)

    # NB: weights[-1] / sum(weights) is the probability of rejecting the new sample `z`.
    accept_prob = 1 - jnp.exp(log_weights_[-1] - logsumexp(log_weights_))
    itr = sa_state.i + 1
    # running mean restarts its count once the first `wa_steps` iterations have passed
    n = jnp.where(sa_state.i < wa_steps, itr, itr - wa_steps)
    mean_accept_prob = (
        sa_state.mean_accept_prob + (accept_prob - sa_state.mean_accept_prob) / n
    )

    # XXX: we make a modification of SA sampler in [1]
    # in [1], each MCMC state contains N points `zs`
    # here we do resampling to pick randomly a point from those N points
    k = random.categorical(rng_key_accept, jnp.zeros(zs.shape[0]))
    z = unravel_fn(zs[k])
    pe = pes[k]
    return SAState(
        itr, z, pe, accept_prob, mean_accept_prob, diverging, adapt_state, rng_key
    )
def lfads_losses(params, lfads_hps, key, x_bxt, kl_scale, keep_rate):
  """Compute the training loss of the LFADS autoencoder.

  Arguments:
    params: a dictionary of LFADS parameters
    lfads_hps: a dictionary of LFADS hyperparameters; this function reads the
      'batch_size', 'var_min' and 'l2reg' entries
    key: random.PRNGKey for random bits
    x_bxt: np array of input with leading dims being batch and time
    kl_scale: scale on KL
    keep_rate: dropout keep rate

  Returns:
    a dictionary of all losses, including the key 'total' used for optimization
  """
  B = lfads_hps['batch_size']
  key, skeys = utils.keygen(key, 2)
  keys = random.split(next(skeys), B)
  lfads = batch_lfads(params, lfads_hps, keys, x_bxt, keep_rate)

  # Sum over time and state dims, average over batch.
  # KL - g0
  ic_post_mean_b = lfads['ic_mean']
  ic_post_logvar_b = lfads['ic_logvar']
  kl_loss_g0_b = dists.batch_kl_gauss_gauss(ic_post_mean_b, ic_post_logvar_b,
                                            params['ic_prior'],
                                            lfads_hps['var_min'])
  kl_loss_g0_prescale = np.sum(kl_loss_g0_b) / B
  kl_loss_g0 = kl_scale * kl_loss_g0_prescale

  # KL - Inferred input
  ii_post_mean_bxt = lfads['ii_mean_t']
  # Read from the 'ii_logvar_t' entry, hence named as a log-variance.
  ii_post_logvar_bxt = lfads['ii_logvar_t']
  keys = random.split(next(skeys), B)
  kl_loss_ii_b = dists.batch_kl_gauss_ar1(keys, ii_post_mean_bxt,
                                          ii_post_logvar_bxt,
                                          params['ii_prior'],
                                          lfads_hps['var_min'])
  kl_loss_ii_prescale = np.sum(kl_loss_ii_b) / B
  kl_loss_ii = kl_scale * kl_loss_ii_prescale

  # Log-likelihood of data given latents.
  lograte_bxt = lfads['lograte_t']
  log_p_xgz = np.sum(dists.poisson_log_likelihood(x_bxt, lograte_bxt)) / B

  # L2 penalty over every leaf of `params`, flattened into a single vector.
  l2reg = lfads_hps['l2reg']
  params_flat, _ = flatten_util.ravel_pytree(params)
  l2_loss = l2reg * np.sum(params_flat**2)

  loss = -log_p_xgz + kl_loss_g0 + kl_loss_ii + l2_loss
  all_losses = {
      'total': loss,
      'nlog_p_xgz': -log_p_xgz,
      'kl_g0': kl_loss_g0,
      'kl_g0_prescale': kl_loss_g0_prescale,
      'kl_ii': kl_loss_ii,
      'kl_ii_prescale': kl_loss_ii_prescale,
      'l2': l2_loss
  }
  return all_losses
def construct_proxy_fn(
    prototype_trace,
    subsample_plate_sizes,
    model,
    model_args,
    model_kwargs,
    num_blocks=1,
):
    """Build a second-order Taylor proxy of the subsampled log likelihood.

    The expansion point is ``reference_params`` (taken from the enclosing scope),
    mapped to unconstrained space with the support bijections of the matching
    sites in `prototype_trace`.

    :param prototype_trace: model trace used to look up each site's support.
    :param dict subsample_plate_sizes: plate name -> ``(size, subsample_size)``.
    :param model: the model callable, run under ``substitute``/``trace`` handlers.
    :param model_args: positional arguments for `model`.
    :param model_kwargs: keyword arguments for `model`.
    :param int num_blocks: number of blocks passed to ``_block_update_proxy``.
    :return: ``(proxy_fn, gibbs_init, gibbs_update)``.
    """
    ref_params = {
        name: biject_to(prototype_trace[name]["fn"].support).inv(value)
        for name, value in reference_params.items()
    }

    ref_params_flat, unravel_fn = ravel_pytree(ref_params)

    def log_likelihood(params_flat, subsample_indices=None):
        """Per-subsample-plate observed log likelihood at `params_flat` (unconstrained)."""
        if subsample_indices is None:
            subsample_indices = {
                k: jnp.arange(v[0]) for k, v in subsample_plate_sizes.items()
            }
        params = unravel_fn(params_flat)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            params = {
                name: biject_to(prototype_trace[name]["fn"].support)(value)
                for name, value in params.items()
            }
            with block(), trace() as tr, substitute(
                data=subsample_indices
            ), substitute(data=params):
                model(*model_args, **model_kwargs)

        log_lik = {}
        for site in tr.values():
            if site["type"] == "sample" and site["is_observed"]:
                for frame in site["cond_indep_stack"]:
                    if frame.name in log_lik:
                        log_lik[frame.name] += _sum_all_except_at_dim(
                            site["fn"].log_prob(site["value"]), frame.dim
                        )
                    elif frame.name in subsample_indices:
                        # Only subsample plates get an entry: other plates are never
                        # used by the proxy, and keeping them would give gibbs_init a
                        # state structure that gibbs_update never reproduces.
                        log_lik[frame.name] = _sum_all_except_at_dim(
                            site["fn"].log_prob(site["value"]), frame.dim
                        )
        return log_lik

    def log_likelihood_sum(params_flat, subsample_indices=None):
        return {
            k: v.sum()
            for k, v in log_likelihood(params_flat, subsample_indices).items()
        }

    # those stats are dict keyed by subsample names
    ref_log_likelihoods_sum = log_likelihood_sum(ref_params_flat)
    ref_log_likelihood_grads_sum = jacobian(log_likelihood_sum)(ref_params_flat)
    ref_log_likelihood_hessians_sum = hessian(log_likelihood_sum)(ref_params_flat)

    def gibbs_init(rng_key, gibbs_sites):
        ref_subsample_log_liks = log_likelihood(ref_params_flat, gibbs_sites)
        ref_subsample_log_lik_grads = jacfwd(log_likelihood)(
            ref_params_flat, gibbs_sites
        )
        ref_subsample_log_lik_hessians = jacfwd(jacfwd(log_likelihood))(
            ref_params_flat, gibbs_sites
        )
        return TaylorProxyState(
            ref_subsample_log_liks,
            ref_subsample_log_lik_grads,
            ref_subsample_log_lik_hessians,
        )

    def gibbs_update(rng_key, gibbs_sites, gibbs_state):
        u_new, pads, new_idxs, starts = _block_update_proxy(
            num_blocks, rng_key, gibbs_sites, subsample_plate_sizes
        )

        new_states = defaultdict(dict)
        ref_subsample_log_liks = log_likelihood(ref_params_flat, new_idxs)
        ref_subsample_log_lik_grads = jacfwd(log_likelihood)(
            ref_params_flat, new_idxs
        )
        ref_subsample_log_lik_hessians = jacfwd(jacfwd(log_likelihood))(
            ref_params_flat, new_idxs
        )
        for stat, new_block_values, last_values in zip(
            ["log_liks", "grads", "hessians"],
            [
                ref_subsample_log_liks,
                ref_subsample_log_lik_grads,
                ref_subsample_log_lik_hessians,
            ],
            [
                gibbs_state.ref_subsample_log_liks,
                gibbs_state.ref_subsample_log_lik_grads,
                gibbs_state.ref_subsample_log_lik_hessians,
            ],
        ):
            for name in gibbs_sites:
                _, subsample_size = subsample_plate_sizes[name]
                pad, start = pads[name], starts[name]
                # pad along the subsample axis, overwrite the refreshed block, then trim
                new_value = jnp.pad(
                    last_values[name],
                    [(0, pad)] + [(0, 0)] * (jnp.ndim(last_values[name]) - 1),
                )
                new_value = lax.dynamic_update_slice_in_dim(
                    new_value, new_block_values[name], start, 0
                )
                new_states[stat][name] = new_value[:subsample_size]
        gibbs_state = TaylorProxyState(
            new_states["log_liks"], new_states["grads"], new_states["hessians"]
        )
        return u_new, gibbs_state

    def proxy_fn(params, subsample_lik_sites, gibbs_state):
        params_flat, _ = ravel_pytree(params)
        params_diff = params_flat - ref_params_flat

        ref_subsample_log_liks = gibbs_state.ref_subsample_log_liks
        ref_subsample_log_lik_grads = gibbs_state.ref_subsample_log_lik_grads
        ref_subsample_log_lik_hessians = gibbs_state.ref_subsample_log_lik_hessians

        proxy_sum = defaultdict(float)
        proxy_subsample = defaultdict(float)
        for name in subsample_lik_sites:
            # value + grad . d + 0.5 * d^T H d, around the reference point
            proxy_subsample[name] = (
                ref_subsample_log_liks[name]
                + jnp.dot(ref_subsample_log_lik_grads[name], params_diff)
                + 0.5
                * jnp.dot(
                    jnp.dot(ref_subsample_log_lik_hessians[name], params_diff),
                    params_diff,
                )
            )

            proxy_sum[name] = (
                ref_log_likelihoods_sum[name]
                + jnp.dot(ref_log_likelihood_grads_sum[name], params_diff)
                + 0.5
                * jnp.dot(
                    jnp.dot(ref_log_likelihood_hessians_sum[name], params_diff),
                    params_diff,
                )
            )
        return proxy_sum, proxy_subsample

    return proxy_fn, gibbs_init, gibbs_update
def init_fn(z, rng_key):
    """Bootstrap the wrapped `kernel` at the flattened `z` and pack the result.

    `kernel` and `TFPKernelState` come from the enclosing scope. The original
    (unflattened) `z` is what gets stored in the returned state.
    """
    flat_z = ravel_pytree(z)[0]
    return TFPKernelState(z, kernel.bootstrap_results(flat_z), rng_key)
def _assertAllClose(self, x, y, rtol):
    """Assert that the relative L1 distance between pytrees `x` and `y` is below `rtol`.

    The distance is ``2 * sum|x - y| / (sum|x| + sum|y| + 1e-4)`` over the
    flattened leaves; the small constant keeps the ratio finite for all-zero inputs.
    """
    x_flat = ravel_pytree(x)[0]
    y_flat = ravel_pytree(y)[0]
    numerator = 2 * np.sum(np.abs(x_flat - y_flat))
    denominator = np.sum(np.abs(x_flat)) + np.sum(np.abs(y_flat)) + 1e-4
    self.assertLess(numerator / denominator, rtol)
"epoch": epoch, "wallclock": time.time() - start_time }) return get_params(opt_state) # See https://github.com/google/jax/issues/7809. binarize = lambda arr: tree_map(lambda x: x > 0.5, arr) print("Training normal model...") everything_mask = tree_map(lambda x: jnp.ones_like(x, dtype=jnp.dtype("bool")), init_params) final_params = train(init_params, everything_mask, "no_mask") # Mask as was implemented in the original paper print("Training lottery ticket model...") final_params_flat, unravel = ravel_pytree(final_params) cutoff = jnp.percentile(jnp.abs(final_params_flat), config.remove_percentile) mask = binarize(unravel(jnp.abs(final_params_flat) > cutoff)) train(init_params, mask, "lottery_mask") print("Training lottery ticket sign model...") # The lottery ticket mask but instead of using the initial weights, just use # the sign of the initial weights. mask = binarize(unravel(jnp.abs(final_params_flat) > cutoff)) w0 = tree_map(lambda x, m: 0.01 * jnp.sign(x) * m, final_params, mask) train(w0, mask, "lottery_sign_mask") # Totally random mask print("Training random mask model...") mask = binarize( unravel(random.uniform(rp.poop(), final_params_flat.shape) > config.remove_percentile / 100))
def jacobian_builder(losses, params):
    """Return the Jacobian of `losses` w.r.t. `params` as a 2-D matrix with `batch_size` rows.

    ``batch_size`` comes from the enclosing scope. Assumes ``losses(params)``
    returns one value per batch element (TODO confirm against callers).
    Row ``b`` holds the gradient for batch element ``b``. Its columns are the
    parameter leaves, flattened and concatenated in pytree order, which is
    the order ``ravel_pytree(params)`` uses.
    """
    grads_tree = jax.jacrev(losses)(params)
    leaves = jax.tree_util.tree_leaves(grads_tree)
    if not leaves:
        # Matches the previous result for an empty parameter tree: no columns.
        return jnp.zeros((batch_size, 0))
    # Reshape each leaf to (batch_size, -1) *before* concatenating. Raveling the
    # whole tree first lays leaves end to end, so one global reshape would mix
    # rows of different batch elements whenever `params` has more than one leaf.
    per_leaf = [jnp.reshape(leaf, (batch_size, -1)) for leaf in leaves]
    return jnp.concatenate(per_leaf, axis=1)
def dynamics_one_flat(t, aug, args):
    """Apply `dynamics_one` to the pytree encoded by the flat vector `aug`.

    `unravel` and `dynamics_one` come from the enclosing scope. The output
    pytree is flattened back into a single vector.
    """
    out_tree = dynamics_one(t, unravel(aug), args)
    return ravel_pytree(out_tree)[0]