import jax
import jax.numpy as jnp
from jax import random, vmap

# `util`, `hk_base`, and `CustomFrame` are project-local; so are the series
# helpers (`jacobian_power_iterations`, `vjp_iterations`,
# `unbiased_neumann_coefficients`) referenced below.

def estimate_bwd(apply_fun, ctx, g):
  dLdz, dLdlogdet, _ = g
  x, params, state, rng, batch_info, dlogdet_dtheta, dlogdet_dx = ctx
  x_shape, batch_shape = batch_info
  batch_axes = tuple(range(len(batch_shape)))

  def multiply_by_val(x):
    return util.broadcast_to_first_axis(dLdlogdet, x.ndim)*x

  # Chain rule through the log det estimate
  dLdtheta = jax.tree_util.tree_map(multiply_by_val, dlogdet_dtheta)
  dLdx = jax.tree_util.tree_map(multiply_by_val, dlogdet_dx)

  # Reduce over the batch axes
  if len(batch_axes) > 0:
    dLdtheta = jax.tree_util.tree_map(lambda x: x.sum(axis=batch_axes), dLdtheta)

  with hk_base.frame_stack(CustomFrame.create_from_params_and_state(params, state)):
    # Compute the partial derivatives wrt x through z = x + g(x)
    _, vjp_fun = jax.vjp(lambda params, x: x + apply_fun(params, state, x, rng)[0],
                         params, x, has_aux=False)
    dtheta, dx = vjp_fun(dLdz)

    # Combine the partial derivatives. (jax.tree_multimap is deprecated;
    # jax.tree_util.tree_map accepts multiple trees.)
    dLdtheta = jax.tree_util.tree_map(lambda x, y: x + y, dLdtheta, dtheta)
    dLdx = jax.tree_util.tree_map(lambda x, y: x + y, dLdx, dx)

  return dLdtheta, None, dLdx, None, None
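# estimate_bwd has the signature of a jax.custom_vjp backward rule with
# apply_fun marked non-differentiable. A minimal sketch of how it could be
# wired up; the wrapper name `res_flow_estimate` is an assumption, and
# `estimate_fwd` is hypothetical: it would need to run the estimator and
# return ((z, log_det, state), ctx) with the context tuple estimate_bwd
# unpacks above.
from functools import partial

@partial(jax.custom_vjp, nondiff_argnums=(0,))
def res_flow_estimate(apply_fun, params, state, x, rng, batch_info):
  z, log_det, _, _, state = log_det_sliced_estimate(
      apply_fun, params, state, x, rng, batch_info)
  return z, log_det, state

# res_flow_estimate.defvjp(estimate_fwd, estimate_bwd)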
def log_det_sliced_estimate(apply_fun, params, state, x, rng, batch_info):
  trace_key, roulette_key = random.split(rng, 2)

  # Evaluate the flow and get the vjp function
  gx, state = apply_fun(params, state, x, rng, update_params=True)
  _, vjp_fun, _ = jax.vjp(lambda x: apply_fun(params, state, x, rng, update_params=False),
                          x, has_aux=True)
  z = x + gx

  # Generate the probe vector for the trace estimate
  v = random.normal(trace_key, x.shape)

  # Get all of the terms we need for the log det and gradient estimates
  terms = unbiased_neumann_vjp_terms(vjp_fun, v, roulette_key, n_terms=7, n_exact=4)

  # Rescale the terms and sum over k (starting at k=1)
  cut_terms = terms[1:]
  log_det_coeff = -1/(1 + jnp.arange(cut_terms.shape[0]))
  log_det_coeff = util.broadcast_to_first_axis(log_det_coeff, cut_terms.ndim)
  log_det_terms = log_det_coeff*cut_terms
  summed_log_det_terms = log_det_terms.sum(axis=0)

  # Compute the log det
  x_shape, batch_shape = batch_info
  log_det = jnp.sum(summed_log_det_terms*v, axis=util.last_axes(x_shape))

  return z, log_det, v, terms, state
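# A toy usage sketch of log_det_sliced_estimate (every name below is
# hypothetical). The "residual block" is linear, gx = x @ A.T with a small A,
# so its Jacobian has spectral norm < 1 and the Neumann series converges.
# Assumes the project-local `util` helpers used above are importable.
def _toy_apply_fun(params, state, x, rng, update_params=True):
  return x @ params["A"].T, state

def _toy_demo():
  key1, key2 = random.split(random.PRNGKey(0))
  dim, batch = 4, 8
  params, state = {"A": 0.1*random.normal(key1, (dim, dim))}, {}
  x = random.normal(key2, (batch, dim))
  batch_info = ((dim,), (batch,))  # (x_shape, batch_shape)
  z, log_det, v, terms, state = log_det_sliced_estimate(
      _toy_apply_fun, params, state, x, key1, batch_info)
  # log_det has shape (batch,); averaged over many probe vectors v it should
  # track the exact log-determinant of the (dim, dim) residual Jacobian, up
  # to the sign convention baked into the coefficient helpers.
  return z, log_det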
def log_det_estimate(apply_fun, params, state, x, rng, batch_info):
  x_shape, batch_shape = batch_info
  assert len(x_shape) == 1, "Not going to implement this for images"

  gx, state = apply_fun(params, state, x, rng)
  z = x + gx

  # Compute the Jacobian of the transform
  jac_fun = jax.jacobian(lambda x: apply_fun(params, state, x[None], rng)[0][0])
  vmap_trace = jnp.trace
  for _ in range(len(batch_shape)):
    jac_fun = vmap(jac_fun)
    vmap_trace = vmap(vmap_trace)

  J = jac_fun(x)

  # Generate the terms of the Neumann series
  terms = neumann_jacobian_terms(J, rng, n_terms=10, n_exact=10)

  # Rescale the terms and sum over k (starting at k=1)
  cut_terms = terms[1:]
  log_det_coeff = -1/(1 + jnp.arange(cut_terms.shape[0]))
  log_det_coeff = util.broadcast_to_first_axis(log_det_coeff, cut_terms.ndim)
  log_det_terms = log_det_coeff*cut_terms
  summed_log_det_terms = log_det_terms.sum(axis=0)

  # Compute the log det
  log_det = vmap_trace(summed_log_det_terms)

  return z, log_det, terms, state
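# For reference, the series both estimators build on: for a residual map
# z = x + g(x) with Jacobian J = dg/dx of spectral norm < 1,
#   log det(I + J) = sum_{k>=1} (-1)^(k+1) tr(J^k) / k.
# Below is a minimal, biased (fixed-truncation) version of that series; the
# functions above instead estimate tr(J^k) with a Hutchinson probe and debias
# the truncation with roulette coefficients. Note only the -1/k rescaling
# appears inline above; any remaining sign convention lives in the
# coefficient helpers, which are not shown in this section.
def truncated_log_det_reference(J, n_terms=10):
  Jk = jnp.eye(J.shape[-1])
  total = 0.0
  for k in range(1, n_terms + 1):
    Jk = Jk @ J
    total = total + (-1.0)**(k + 1)*jnp.trace(Jk)/k
  return total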
def neumann_jacobian_terms(J, rng, n_terms=10, n_exact=4):
  terms = jacobian_power_iterations(J, n_terms)

  # Compute the coefficients for each term
  coeff = unbiased_neumann_coefficients(rng, n_terms, n_exact)
  coeff = util.broadcast_to_first_axis(coeff, terms.ndim)

  return coeff*terms
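# jacobian_power_iterations is not defined in this section. A plausible
# stand-in (hypothetical), assuming it stacks the matrix powers
# [I, J, J^2, ...] of an unbatched (dim, dim) Jacobian along a leading axis,
# so that terms start at k=0 as the callers above expect:
def jacobian_power_iterations_sketch(J, n_terms):
  def scan_body(Jk, _):
    Jk_next = Jk @ J
    return Jk_next, Jk_next
  I = jnp.eye(J.shape[-1])
  _, powers = jax.lax.scan(scan_body, I, None, length=n_terms - 1)
  return jnp.concatenate([I[None], powers], axis=0)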
def unbiased_neumann_vjp_terms(vjp_fun, v, rng, n_terms=10, n_exact=4):
  # This function assumes that we start at k=0!
  # Compute the terms in the power series.
  terms = vjp_iterations(vjp_fun, v, n_terms)

  # Compute the coefficients for each term
  coeff = unbiased_neumann_coefficients(rng, n_terms, n_exact)
  coeff = util.broadcast_to_first_axis(coeff, terms.ndim)

  return coeff*terms
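# Neither vjp_iterations nor unbiased_neumann_coefficients is defined in this
# section; both sketches below are hypothetical stand-ins. The first stacks
# [v, vJ, vJ^2, ...] by repeatedly pulling v back through vjp_fun. The second
# implements Russian-roulette reweighting in the style of Residual Flows
# (Chen et al., 2019): the first n_exact coefficients are 1, the tail is cut
# at a geometrically sampled index N, and each surviving tail term j is
# divided by P(N >= j), which keeps the randomly truncated sum unbiased.
def vjp_iterations_sketch(vjp_fun, v, n_terms):
  def scan_body(w, _):
    w, = vjp_fun(w)
    return w, w
  _, tail = jax.lax.scan(scan_body, v, None, length=n_terms - 1)
  return jnp.concatenate([v[None], tail], axis=0)

def unbiased_neumann_coefficients_sketch(key, n_terms, n_exact, p_stop=0.5):
  u = random.uniform(key)
  N = jnp.floor(jnp.log(u)/jnp.log1p(-p_stop)).astype(jnp.int32)  # Geometric(p_stop)
  k = jnp.arange(n_terms)
  j = k - n_exact                               # position within the tail
  survival = (1.0 - p_stop)**jnp.maximum(j, 0)  # P(N >= j)
  return jnp.where(k < n_exact, 1.0, (j <= N)/survival)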