def temporary_internal_state(state: InternalState):
    """Build a frame stack from the current frame overridden by ``state``.

    The current frame is evolved with ``state.params`` and ``state.state``.
    ``state.rng`` is wrapped in a ``base.PRNGSequence`` unless it is ``None``,
    in which case ``None`` is passed through unchanged.

    Args:
        state: Object exposing ``rng``, ``params`` and ``state`` attributes.

    Returns:
        Whatever ``base.frame_stack`` returns for the evolved frame
        (presumably a context manager -- confirm against ``base``).
    """
    key = state.rng
    sequence = None if key is None else base.PRNGSequence(key)
    evolved = base.current_frame().evolve(
        params=state.params, state=state.state, rng=sequence)
    return base.frame_stack(evolved)
def sliced_estimate_bwd(apply_fun, ctx, g):
    """Backward rule combining stored log-det gradients with a VJP of the flow.

    Args:
        apply_fun: Callable invoked as
            ``apply_fun(params, state, x, rng, update_params=False)``; only
            element ``[0]`` of its result is differentiated.
        ctx: 7-tuple ``(x, params, state, rng, batch_info, dlogdet_dtheta,
            dlogdet_dx)``. ``batch_info`` is a pair whose second element is
            the batch shape.
        g: 3-tuple of cotangents ``(dLdz, dLdlogdet, _)``; the third entry is
            ignored.

    Returns:
        ``(dLdtheta, None, dLdx, None, None)``. The position of the ``None``
        entries presumably mirrors the non-differentiated primal inputs --
        confirm against the forward rule.
    """
    dLdz, dLdlogdet, _ = g
    x, params, state, rng, batch_info, dlogdet_dtheta, dlogdet_dx = ctx
    # Only the batch shape is needed; the unpack still validates the pair.
    _, batch_shape = batch_info
    batch_dim = len(batch_shape)
    batch_axes = tuple(range(batch_dim))

    if batch_dim > 0:
        def multiply_by_val(leaf):
            # Align dLdlogdet with the leading axes of each leaf.
            return util.broadcast_to_first_axis(dLdlogdet, leaf.ndim) * leaf
    else:
        def multiply_by_val(leaf):
            return dLdlogdet * leaf

    # Get the gradients wrt x and theta using the terms from the forward step.
    dLdtheta = jax.tree_util.tree_map(multiply_by_val, dlogdet_dtheta)
    dLdx = jax.tree_util.tree_map(multiply_by_val, dlogdet_dx)

    # Reduce the parameter gradients over the batch axes.
    if batch_dim > 0:
        dLdtheta = jax.tree_util.tree_map(
            lambda leaf: leaf.sum(axis=batch_axes), dLdtheta)

    # Open up a new frame so that apply_fun can retrieve the parameters.
    with hk_base.frame_stack(
            CustomFrame.create_from_params_and_state(params, state)):
        # Add in the partial derivatives of x + apply_fun(...)[0].
        _, vjp_fun = jax.vjp(
            lambda params, x: x + apply_fun(
                params, state, x, rng, update_params=False)[0],
            params, x, has_aux=False)
        dtheta, dx = vjp_fun(dLdz)

        # Combine the partial derivatives. jax.tree_util.tree_map accepts
        # several trees, replacing the removed jax.tree_multimap alias.
        dLdtheta = jax.tree_util.tree_map(
            lambda a, b: a + b, dLdtheta, dtheta)
        dLdx = jax.tree_util.tree_map(lambda a, b: a + b, dLdx, dx)

        return dLdtheta, None, dLdx, None, None
def estimate_bwd(apply_fun, ctx, g):
    """Backward rule combining stored log-det gradients with a VJP of the flow.

    Args:
        apply_fun: Callable invoked as ``apply_fun(params, state, x, rng)``;
            only element ``[0]`` of its result is differentiated.
        ctx: 7-tuple ``(x, params, state, rng, batch_info, dlogdet_dtheta,
            dlogdet_dx)``. ``batch_info`` is a pair whose second element is
            the batch shape.
        g: 3-tuple of cotangents ``(dLdz, dLdlogdet, _)``; the third entry is
            ignored.

    Returns:
        ``(dLdtheta, None, dLdx, None, None)``. The position of the ``None``
        entries presumably mirrors the non-differentiated primal inputs --
        confirm against the forward rule.
    """
    dLdz, dLdlogdet, _ = g
    x, params, state, rng, batch_info, dlogdet_dtheta, dlogdet_dx = ctx
    # Only the batch shape is needed; the unpack still validates the pair.
    _, batch_shape = batch_info
    batch_axes = tuple(range(len(batch_shape)))

    def scale_by_dlogdet(leaf):
        # Align dLdlogdet with the leading axes of each leaf.
        return util.broadcast_to_first_axis(dLdlogdet, leaf.ndim) * leaf

    dLdtheta = jax.tree_util.tree_map(scale_by_dlogdet, dlogdet_dtheta)
    dLdx = jax.tree_util.tree_map(scale_by_dlogdet, dlogdet_dx)

    # Reduce the parameter gradients over the batch axes.
    if len(batch_axes) > 0:
        dLdtheta = jax.tree_util.tree_map(
            lambda leaf: leaf.sum(axis=batch_axes), dLdtheta)

    # apply_fun is called inside a frame built from params and state.
    with hk_base.frame_stack(
            CustomFrame.create_from_params_and_state(params, state)):
        # Compute the partial derivatives of x + apply_fun(...)[0].
        _, vjp_fun = jax.vjp(
            lambda params, x: x + apply_fun(params, state, x, rng)[0],
            params, x, has_aux=False)
        dtheta, dx = vjp_fun(dLdz)

        # Combine the partial derivatives. jax.tree_util.tree_map accepts
        # several trees, replacing the removed jax.tree_multimap alias.
        dLdtheta = jax.tree_util.tree_map(
            lambda a, b: a + b, dLdtheta, dtheta)
        dLdx = jax.tree_util.tree_map(lambda a, b: a + b, dLdx, dx)

        return dLdtheta, None, dLdx, None, None
def temporary_frame_data(frame_data: FrameData):
    """Pushes a temporary copy of the frame_data.

    The input is first duplicated with ``copy_structure``. Its ``rng`` is
    wrapped in a ``PRNGSequence`` unless it is ``None``. The current frame is
    then evolved with the copied ``params``, ``state``, ``constants`` and the
    wrapped rng, and handed to ``frame_stack``.

    Args:
        frame_data: Object exposing ``rng``, ``params``, ``state`` and
            ``constants`` attributes.

    Returns:
        Whatever ``frame_stack`` returns for the evolved frame.

    Raises:
        AssertionError: If ``params``, ``state`` or ``constants`` is ``None``.
    """
    snapshot = copy_structure(frame_data)

    key = snapshot.rng
    sequence = None if key is None else PRNGSequence(key)

    # Insertion order fixes the order of the checks: params, state, constants.
    overrides = {
        "params": snapshot.params,
        "state": snapshot.state,
        "constants": snapshot.constants,
    }
    for value in overrides.values():
        assert value is not None, "Must initialize module before this call"

    evolved = current_frame().evolve(**overrides, rng=sequence)
    return frame_stack(evolved)
def temporary_internal_state(state: InternalState):
    """Pushes a temporary copy of the internal state.

    The input is duplicated with ``copy_structure``. Its ``rng`` is wrapped in
    a ``base.PRNGSequence`` unless it is ``None``. Any of ``params`` / ``state``
    that is ``None`` on the copy falls back to the corresponding attribute of
    ``internal_state()``.

    Args:
        state: Object exposing ``rng``, ``params`` and ``state`` attributes.

    Returns:
        Whatever ``base.frame_stack`` returns for the current frame evolved
        with the resolved params, state and rng.
    """
    snapshot = copy_structure(state)

    key = snapshot.rng
    sequence = base.PRNGSequence(key) if key is not None else None

    # Values already active; used wherever the snapshot leaves a field unset.
    fallback = internal_state()

    resolved_params = snapshot.params
    if resolved_params is None:
        resolved_params = fallback.params

    resolved_state = snapshot.state
    if resolved_state is None:
        resolved_state = fallback.state

    evolved = base.current_frame().evolve(
        params=resolved_params, state=resolved_state, rng=sequence)
    return base.frame_stack(evolved)
def fixed_point_bwd(apply_fun, ctx, dLdx):
    """Refuse to backpropagate through the fixed point iteration.

    The guard is an explicit ``raise`` rather than ``assert 0`` so that it
    still fires when Python runs with ``-O`` (which strips assert statements).
    The exception type and message are unchanged.

    The experimental backward pass that used to follow the guard is kept in
    ``_fixed_point_bwd_experimental`` and is not called from here.

    Args:
        apply_fun: Unused; accepted for interface compatibility.
        ctx: Unused; accepted for interface compatibility.
        dLdx: Unused; accepted for interface compatibility.

    Raises:
        AssertionError: Always.
    """
    raise AssertionError(
        "Bro, did you seriously just try to backprop through a fixed point iteration?")


def _fixed_point_bwd_experimental(apply_fun, ctx, dLdx):
    """Disabled backward pass for the fixed point; kept for reference only.

    Args:
        apply_fun: Callable invoked as ``apply_fun(params, state, x)``; only
            element ``[0]`` of its result is differentiated.
        ctx: 5-tuple ``(params, state, x, z, roulette_rng)``.
        dLdx: Incoming cotangent.

    Returns:
        ``(dparams, None, dx_star, None)``.
    """
    params, state, x, z, roulette_rng = ctx

    with hk_base.frame_stack(
            CustomFrame.create_from_params_and_state(params, state)):
        # VJP of the fixed point map wrt x. (A steffensen_fixed_point variant
        # was previously tried here in place of contractive_fixed_point.)
        _, vjp_x = jax.vjp(
            lambda x: contractive_fixed_point(apply_fun, params, state, x, z),
            x)

        def rev_iter(zeta):
            zetaT_dFdx, = vjp_x(zeta)
            return dLdx + zetaT_dFdx

        # 100 and 1e-4 are presumably the iteration cap and tolerance of
        # _fixed_point -- confirm against its definition.
        zeta, _ = _fixed_point(rev_iter, dLdx, 100, 1e-4)

        # Go from zeta to the gradient of the frame data.
        _, vjp_u = jax.vjp(
            lambda params: contractive_fixed_point(
                apply_fun, params, state, x, z),
            params)
        dparams, = vjp_u(zeta)

        # Also handle the gradient wrt z here. To do this, we need to solve
        # (dx/dz)^{-1}dx. Do this with vjps against terms in the neumann
        # series for dx/dz.
        _, vjp_apply = jax.vjp(
            lambda x: apply_fun(params, state, x)[0], x, has_aux=False)
        terms = unbiased_neumann_vjp_terms(
            vjp_apply, dLdx, roulette_rng, n_terms=40, n_exact=40)
        dx_star = terms.sum(axis=0)

        return dparams, None, dx_star, None