def estimate_bwd(apply_fun, ctx, g):
    """Combine the log-det gradient estimates with the VJP of ``x + apply_fun(...)``.

    NOTE(review): the signature looks like the backward rule of a
    ``jax.custom_vjp`` with ``apply_fun`` as a non-differentiable argument --
    confirm against the registration site.

    Args:
        apply_fun: Callable invoked as ``apply_fun(params, state, x, rng)``;
            element ``[0]`` of its result is added to ``x`` to form the output
            that ``dLdz`` is pulled back through.
        ctx: Tuple ``(x, params, state, rng, batch_info, dlogdet_dtheta,
            dlogdet_dx)`` where ``batch_info`` is ``(x_shape, batch_shape)``.
        g: Tuple ``(dLdz, dLdlogdet, _)`` of output cotangents; the third
            entry is ignored.

    Returns:
        The 5-tuple ``(dLdtheta, None, dLdx, None, None)``.
    """
    dLdz, dLdlogdet, _ = g
    x, params, state, rng, batch_info, dlogdet_dtheta, dlogdet_dx = ctx
    _, batch_shape = batch_info
    batch_axes = tuple(range(len(batch_shape)))

    def scale_by_dLdlogdet(leaf):
        # Chain rule through the log-det output: weight each stored
        # derivative leaf by the incoming log-det cotangent.
        return util.broadcast_to_first_axis(dLdlogdet, leaf.ndim) * leaf

    dLdtheta = jax.tree_util.tree_map(scale_by_dLdlogdet, dlogdet_dtheta)
    dLdx = jax.tree_util.tree_map(scale_by_dLdlogdet, dlogdet_dx)

    # Reduce the parameter gradient over the leading batch axes (dLdx keeps them).
    if batch_axes:
        dLdtheta = jax.tree_util.tree_map(
            lambda leaf: leaf.sum(axis=batch_axes), dLdtheta)

    with hk_base.frame_stack(
            CustomFrame.create_from_params_and_state(params, state)):

        # Pull dLdz back through z = x + apply_fun(params, state, x, rng)[0]
        _, vjp_fun = jax.vjp(
            lambda params, x: x + apply_fun(params, state, x, rng)[0],
            params,
            x,
            has_aux=False)
        dtheta, dx = vjp_fun(dLdz)

        # Combine the two contributions leaf by leaf. jax.tree_util.tree_map
        # accepts several trees; jax.tree_multimap / jax.tree_map are gone
        # from recent JAX releases.
        dLdtheta = jax.tree_util.tree_map(
            lambda a, b: a + b, dLdtheta, dtheta)
        dLdx = jax.tree_util.tree_map(lambda a, b: a + b, dLdx, dx)

        return dLdtheta, None, dLdx, None, None
def log_det_sliced_estimate(apply_fun, params, state, x, rng, batch_info):
  """Estimate the log det using a single random probe vector and vjp terms.

  Args:
    apply_fun: Called as ``apply_fun(params, state, x, rng, update_params=...)``
      and expected to return a ``(gx, state)`` pair.
    params, state: Forwarded to ``apply_fun``.
    x: Input; the output is ``z = x + gx``.
    rng: Key passed to ``apply_fun`` and also split into the probe key and the
      key handed to ``unbiased_neumann_vjp_terms``.
    batch_info: Pair ``(x_shape, batch_shape)``; only ``x_shape`` is used.

  Returns:
    ``(z, log_det, v, terms, state)`` where ``state`` is the one returned by
    the ``update_params=True`` call.
  """
  probe_key, series_key = random.split(rng, 2)

  # First call may update the state; the vjp is then taken with that state frozen
  residual, new_state = apply_fun(params, state, x, rng, update_params=True)

  def residual_fn(inputs):
    return apply_fun(params, new_state, inputs, rng, update_params=False)

  _, vjp_fun, _ = jax.vjp(residual_fn, x, has_aux=True)
  z = x + residual

  # Probe vector for the trace estimate
  v = random.normal(probe_key, x.shape)

  # All of the series terms (also returned to the caller)
  terms = unbiased_neumann_vjp_terms(vjp_fun, v, series_key, n_terms=7, n_exact=4)

  # Drop the k=0 term, weight term k by -1/k and sum over k
  tail = terms[1:]
  weights = util.broadcast_to_first_axis(
      -1/(1 + jnp.arange(tail.shape[0])), tail.ndim)
  weighted_sum = (weights*tail).sum(axis=0)

  # Contract with the probe over the axes selected by util.last_axes(x_shape)
  x_shape, _ = batch_info
  log_det = jnp.sum(weighted_sum*v, axis=util.last_axes(x_shape))

  return z, log_det, v, terms, new_state
def log_det_estimate(apply_fun, params, state, x, rng, batch_info):
    """Compute the log det from the explicit Jacobian of ``apply_fun``.

    Args:
        apply_fun: Called as ``apply_fun(params, state, x, rng)`` and expected
            to return a ``(gx, state)`` pair.
        params, state: Forwarded to ``apply_fun``.
        x: Input; the output is ``z = x + gx``.
        rng: Key passed to ``apply_fun`` and to ``neumann_jacobian_terms``.
        batch_info: Pair ``(x_shape, batch_shape)``; ``x_shape`` must have
            length 1.

    Returns:
        ``(z, log_det, terms, state)``.
    """
    x_shape, batch_shape = batch_info
    assert len(x_shape) == 1, "Not going to implement this for images"

    residual, new_state = apply_fun(params, state, x, rng)
    z = x + residual

    def single_example_fn(inputs):
        # Add a leading axis of size 1 for apply_fun, then strip it again
        return apply_fun(params, new_state, inputs[None], rng)[0][0]

    # Jacobian and trace, each vmapped once per batch axis
    batched_jac = jax.jacobian(single_example_fn)
    batched_trace = jnp.trace
    n_batch_axes = len(batch_shape)
    for _ in range(n_batch_axes):
        batched_jac = vmap(batched_jac)
        batched_trace = vmap(batched_trace)

    jacobian = batched_jac(x)

    # Terms of the neumann series (also returned to the caller)
    terms = neumann_jacobian_terms(jacobian, rng, n_terms=10, n_exact=10)

    # Drop the k=0 term, weight term k by -1/k and sum over k
    tail = terms[1:]
    weights = util.broadcast_to_first_axis(
        -1 / (1 + jnp.arange(tail.shape[0])), tail.ndim)
    weighted_sum = (weights * tail).sum(axis=0)

    log_det = batched_trace(weighted_sum)

    return z, log_det, terms, new_state
def neumann_jacobian_terms(J, rng, n_terms=10, n_exact=4):
  """Return the Jacobian power-iteration terms scaled by the series coefficients.

  The terms come from ``jacobian_power_iterations(J, n_terms)`` and the
  coefficients from ``unbiased_neumann_coefficients(rng, n_terms, n_exact)``;
  the coefficients are broadcast along the first axis of the terms.
  """
  powers = jacobian_power_iterations(J, n_terms)

  # One coefficient per term, reshaped so it multiplies along the first axis
  weights = unbiased_neumann_coefficients(rng, n_terms, n_exact)
  return util.broadcast_to_first_axis(weights, powers.ndim)*powers
# Example #5
# 0
def unbiased_neumann_vjp_terms(vjp_fun, v, rng, n_terms=10, n_exact=4):
    """Return the vjp power-series terms scaled by the series coefficients.

    This function assumes that we start at k=0!

    The terms come from ``vjp_iterations(vjp_fun, v, n_terms)`` and the
    coefficients from ``unbiased_neumann_coefficients(rng, n_terms, n_exact)``;
    the coefficients are broadcast along the first axis of the terms.
    """
    series = vjp_iterations(vjp_fun, v, n_terms)

    # One coefficient per term, reshaped so it multiplies along the first axis
    weights = unbiased_neumann_coefficients(rng, n_terms, n_exact)
    return util.broadcast_to_first_axis(weights, series.ndim) * series
 def multiply_by_val(x):
   """Scale x by dLdlogdet (read from the enclosing scope)."""
   # Reshape the cotangent via util.broadcast_to_first_axis so it multiplies x
   scale = util.broadcast_to_first_axis(dLdlogdet, x.ndim)
   return scale*x