def apply_gradient(self, hyper_params, params, state, grads):
  """Applies a gradient for a set of parameters.

  Args:
    hyper_params: a named tuple of hyper parameters.
    params: the parameters that should be updated.
    state: a named tuple containing the state of the optimizer.
    grads: the gradient tensors for the parameters.

  Returns:
    A tuple containing the new parameters and the new optimizer state.
  """
  step = state.step
  # Flatten params once and reuse the treedef so the per-parameter states and
  # gradients line up leaf-for-leaf with params_flat.
  params_flat, treedef = jax.tree_flatten(params)
  states_flat = treedef.flatten_up_to(state.param_states)
  grads_flat = treedef.flatten_up_to(grads)
  # Update the Shampoo statistics for each parameter independently.
  new_states_flat = [
      self.compute_shampoo_statistics(step, hyper_params, param, state, grad)
      for param, state, grad in zip(params_flat, states_flat, grads_flat)
  ]
  # Compute preconditioners from the updated statistics of all parameters.
  new_states_flat = self.compute_preconditioners_from_statistics(
      new_states_flat, hyper_params, step)
  # Apply the preconditioned update to each parameter.
  out = [
      self.apply_per_param_gradient(step, hyper_params, param, state, grad)
      for param, state, grad in zip(params_flat, new_states_flat, grads_flat)
  ]
  new_params_flat, new_states_flat = list(zip(*out)) if out else ((), ())
  new_params = jax.tree_unflatten(treedef, new_params_flat)
  new_param_states = jax.tree_unflatten(treedef, new_states_flat)
  new_state = OptimizerState(step + 1, new_param_states)
  return new_params, new_state
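# A minimal, self-contained sketch (toy values, not part of the optimizer) of
# the pytree pattern apply_gradient relies on: flatten the params pytree once,
# reuse its treedef via flatten_up_to to split the per-parameter states and
# gradients into matching leaf lists, operate leaf-wise, and unflatten at the
# end. Only standard jax.tree_util calls are used here.
import jax

example_params = {"dense": {"kernel": 1.0, "bias": 2.0}}
example_grads = {"dense": {"kernel": 0.1, "bias": 0.2}}

leaves, example_treedef = jax.tree_util.tree_flatten(example_params)
grad_leaves = example_treedef.flatten_up_to(example_grads)  # same leaf order as leaves

updated_leaves = [p - 0.01 * g for p, g in zip(leaves, grad_leaves)]
new_example_params = jax.tree_util.tree_unflatten(example_treedef, updated_leaves)
# new_example_params has the same {"dense": {"kernel": ..., "bias": ...}} structure.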
def _optax_state_to_optim_state(self, optax_state):
  """Converts an optax-style state (count, stats) into an OptimizerState."""
  return OptimizerState(optax_state.count, optax_state.stats)
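# A minimal sketch (the container names below are illustrative assumptions;
# the real definitions live elsewhere in the codebase) of the field mapping
# the conversion above performs: an optax-style state exposes `count` and
# `stats`, while the flax-optim style state built in apply_gradient exposes
# `step` and `param_states`. The values are passed through unchanged.
from typing import Any, NamedTuple

class _OptaxLikeState(NamedTuple):  # hypothetical stand-in
  count: Any  # step counter (optax convention)
  stats: Any  # pytree of per-parameter Shampoo statistics

class _OptimState(NamedTuple):  # hypothetical stand-in for OptimizerState
  step: Any          # step counter (flax-optim convention)
  param_states: Any  # pytree of per-parameter states

src = _OptaxLikeState(count=0, stats={})
dst = _OptimState(step=src.count, param_states=src.stats)  # mirrors the method above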