Example #1
0
def temporary_internal_state(state: InternalState):
    """Pushes a frame built from ``state`` onto the Haiku frame stack."""
    # Wrap the raw rng key in a PRNGSequence only when one was provided.
    seq = None if state.rng is None else base.PRNGSequence(state.rng)
    new_frame = base.current_frame().evolve(
        params=state.params, state=state.state, rng=seq)
    return base.frame_stack(new_frame)
def sliced_estimate_bwd(apply_fun, ctx, g):
  """Backward pass for the sliced log-det estimator.

  Scales the log-det partials saved during the forward pass by the incoming
  ``dLdlogdet`` cotangent, then adds the vjp of the residual map
  ``x + apply_fun(...)`` against ``dLdz``.

  Returns a 5-tuple matching the forward inputs: (dLdtheta, None, dLdx,
  None, None).
  """
  dLdz, dLdlogdet, _ = g
  x, params, state, rng, batch_info, dlogdet_dtheta, dlogdet_dx = ctx
  x_shape, batch_shape = batch_info
  batch_dim = len(batch_shape)
  batch_axes = tuple(range(batch_dim))

  # With batch axes present, dLdlogdet has to be broadcast against each
  # leaf's leading axes; otherwise a plain scalar multiply suffices.
  if batch_dim > 0:
    def multiply_by_val(x):
      return util.broadcast_to_first_axis(dLdlogdet, x.ndim)*x
  else:
    def multiply_by_val(x):
      return dLdlogdet*x

  # Get the gradients wrt x and theta using the terms from the forward step
  dLdtheta = jax.tree_util.tree_map(multiply_by_val, dlogdet_dtheta)
  dLdx     = jax.tree_util.tree_map(multiply_by_val, dlogdet_dx)

  # Reduce over the batch axes
  # (jax.tree_map was removed from the top-level namespace; use
  # jax.tree_util.tree_map, consistent with the calls above.)
  if batch_dim > 0:
    dLdtheta = jax.tree_util.tree_map(lambda x: x.sum(axis=batch_axes), dLdtheta)

  # Open up a new frame so that apply_fun can retrieve the parameters
  with hk_base.frame_stack(CustomFrame.create_from_params_and_state(params, state)):

    # Add in the partial derivatives wrt x
    _, vjp_fun = jax.vjp(lambda params, x: x + apply_fun(params, state, x, rng, update_params=False)[0], params, x, has_aux=False)
    dtheta, dx = vjp_fun(dLdz)

    # Combine the partial derivatives.  jax.tree_multimap was removed from
    # JAX; tree_map accepts multiple trees with the same structure.
    dLdtheta = jax.tree_util.tree_map(lambda x, y: x + y, dLdtheta, dtheta)
    dLdx     = jax.tree_util.tree_map(lambda x, y: x + y, dLdx, dx)

    return dLdtheta, None, dLdx, None, None
def estimate_bwd(apply_fun, ctx, g):
    """Backward pass for the log-det estimator.

    Scales the log-det partials saved during the forward pass by the
    incoming ``dLdlogdet`` cotangent, then adds the vjp of the residual map
    ``x + apply_fun(...)`` against ``dLdz``.

    Returns a 5-tuple matching the forward inputs: (dLdtheta, None, dLdx,
    None, None).
    """
    dLdz, dLdlogdet, _ = g
    x, params, state, rng, batch_info, dlogdet_dtheta, dlogdet_dx = ctx
    x_shape, batch_shape = batch_info
    batch_axes = tuple(range(len(batch_shape)))

    # Scale each saved partial by dL/dlogdet, broadcast to the leaf's rank.
    dLdtheta = jax.tree_util.tree_map(
        lambda x: util.broadcast_to_first_axis(dLdlogdet, x.ndim) * x,
        dlogdet_dtheta)
    dLdx = jax.tree_util.tree_map(
        lambda x: util.broadcast_to_first_axis(dLdlogdet, x.ndim) * x,
        dlogdet_dx)

    # Reduce over the batch axes
    # (jax.tree_map was removed from the top-level namespace; use
    # jax.tree_util.tree_map, consistent with the calls above.)
    if len(batch_axes) > 0:
        dLdtheta = jax.tree_util.tree_map(
            lambda x: x.sum(axis=batch_axes), dLdtheta)

    # Open a frame so that apply_fun can retrieve the parameters.
    with hk_base.frame_stack(
            CustomFrame.create_from_params_and_state(params, state)):

        # Compute the partial derivatives wrt x
        _, vjp_fun = jax.vjp(
            lambda params, x: x + apply_fun(params, state, x, rng)[0],
            params,
            x,
            has_aux=False)
        dtheta, dx = vjp_fun(dLdz)

        # Combine the partial derivatives.  jax.tree_multimap was removed
        # from JAX; tree_map accepts multiple trees of the same structure.
        dLdtheta = jax.tree_util.tree_map(lambda x, y: x + y, dLdtheta, dtheta)
        dLdx = jax.tree_util.tree_map(lambda x, y: x + y, dLdx, dx)

        return dLdtheta, None, dLdx, None, None
Example #4
0
def temporary_frame_data(frame_data: FrameData):
  """Pushes a temporary copy of the frame_data."""
  # Operate on a copy so the caller's structure is never mutated.
  frame_data = copy_structure(frame_data)
  rng = frame_data.rng
  if rng is not None:
    rng = PRNGSequence(rng)
  params, state, constants = (frame_data.params,
                              frame_data.state,
                              frame_data.constants)
  # Each piece is produced by module initialization; all must exist.
  assert params is not None, "Must initialize module before this call"
  assert state is not None, "Must initialize module before this call"
  assert constants is not None, "Must initialize module before this call"

  new_frame = current_frame().evolve(
      params=params, state=state, constants=constants, rng=rng)
  return frame_stack(new_frame)
Example #5
0
def temporary_internal_state(state: InternalState):
    """Pushes a temporary copy of the internal state."""
    state = copy_structure(state)
    # Wrap the raw rng key in a PRNGSequence only when one was provided.
    rng = None if state.rng is None else base.PRNGSequence(state.rng)
    # Fall back to the live internal state for any piece the caller left
    # unset.  Distinct local names avoid shadowing the `state` argument.
    live = internal_state()
    params = live.params if state.params is None else state.params
    hk_state = live.state if state.state is None else state.state
    frame = base.current_frame().evolve(params=params, state=hk_state, rng=rng)
    return base.frame_stack(frame)
def fixed_point_bwd(apply_fun, ctx, dLdx):
    """Backward pass for a fixed point iteration.

    Backpropagating through the forward iteration is deliberately disabled:
    the guard below fires unconditionally.  The implicit-vjp code that
    follows is kept (unreachable) for reference.

    Raises:
      NotImplementedError: always.
    """
    # The original `assert 0, ...` guard is stripped under `python -O`,
    # which would let the disabled code below run silently.  Raise an
    # explicit exception instead so the guard holds in all modes.
    raise NotImplementedError(
        "Backpropagation through the fixed point iteration is not supported.")
    params, state, x, z, roulette_rng = ctx

    with hk_base.frame_stack(
            CustomFrame.create_from_params_and_state(params, state)):

        # vjp against the fixed-point map wrt x, used in the reverse iteration.
        _, vjp_x = jax.vjp(
            lambda x: contractive_fixed_point(apply_fun, params, state, x, z),
            x)

        def rev_iter(zeta):
            # One step of the reverse fixed-point iteration:
            # zeta <- dLdx + zeta^T dF/dx
            zetaT_dFdx, = vjp_x(zeta)
            return dLdx + zetaT_dFdx

        zeta, N = _fixed_point(rev_iter, dLdx, 100, 1e-4)

        # Go from zeta to the gradient of the frame data
        _, vjp_u = jax.vjp(
            lambda params: contractive_fixed_point(apply_fun, params, state, x,
                                                   z), params)
        dparams, = vjp_u(zeta)

        # Also handle the gradient wrt z here.  To do this, we need to solve
        # (dx/dz)^{-1}dx with vjps against terms in the Neumann series for dx/dz.
        _, vjp_x = jax.vjp(lambda x: apply_fun(params, state, x)[0],
                           x,
                           has_aux=False)
        terms = unbiased_neumann_vjp_terms(vjp_x,
                                           dLdx,
                                           roulette_rng,
                                           n_terms=40,
                                           n_exact=40)
        dx_star = terms.sum(axis=0)

        return dparams, None, dx_star, None