def update(self, params, state, epoch, *args, **kwargs): """Perform one update of the algorithm. Args: params: pytree containing the parameters. state: named tuple containing the solver state. epoch: number of epoch. *args: additional positional arguments to be passed to ``fun``. **kwargs: additional keyword arguments to be passed to ``fun``. Return type: base.OptStep Returns: (params, state) """ del epoch # unused if self.lmbda == 1: raise ValueError( 'lmbda =1 was passed to SPSsqrt solver. This solver does not work with lmbda =1 because then the parameters are never updated! ' ) if self.has_aux: (value, aux), grad = self._value_and_grad_fun(params, *args, **kwargs) else: value, grad = self._value_and_grad_fun(params, *args, **kwargs) aux = None # If slack hits zero, reset to be the current value. # This stops the method from halting. if state.slack == 0.0: state = state._replace(slack=value) ## The mathematical expression of the this update is: # step = (value - (1-lmbda/2) sqrt(s))_+) / (4s||grad||^2 + 1 - lmbda) # w = w - 4 step s *grad, # s = (1-lmbda)*sqrt(s)*(sqrt(s) + step) step_size = jax.nn.relu( value - (1 - self.lmbda / 2) * jnp.sqrt(state.slack)) / ( 4 * state.slack * tree_l2_norm(grad, squared=True) + 1 - self.lmbda) newslack = (1 - self.lmbda) * jnp.sqrt( state.slack) * (jnp.sqrt(state.slack) + step_size) step_size = 4 * state.slack * step_size if self.momentum == 0: new_params = tree_add_scalar_mul(params, -step_size, grad) new_velocity = None else: # new_v = momentum * v - step_size * grad # new_params = params + new_v new_velocity = tree_sub( tree_scalar_mul(self.momentum, state.velocity), tree_scalar_mul(step_size, grad)) new_params = tree_add(params, new_velocity) new_state = SPSsqrtState(iter_num=state.iter_num + 1, value=value, slack=newslack, velocity=new_velocity, aux=aux) return base.OptStep(params=new_params, state=new_state)
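# A minimal standalone sketch (illustration only, not part of the solver):
# the SPSsqrt step-size and slack recursion above, written out for a toy
# scalar loss. The loss f, the starting slack and the lmbda value are
# assumptions chosen only to show how the slack and iterate evolve.
import jax
import jax.numpy as jnp


def spssqrt_toy_step(w, slack, lmbda=0.5):
  f = lambda w: 0.5 * (w - 1.0) ** 2
  value, grad = jax.value_and_grad(f)(w)
  # step = (f(w) - (1 - lmbda/2) * sqrt(s))_+ / (4 s ||grad||^2 + 1 - lmbda)
  step = jax.nn.relu(value - (1 - lmbda / 2) * jnp.sqrt(slack)) / (
      4 * slack * grad ** 2 + 1 - lmbda)
  new_slack = (1 - lmbda) * jnp.sqrt(slack) * (jnp.sqrt(slack) + step)
  # The parameter update uses the rescaled step 4 * s * step, as in the solver.
  new_w = w - 4 * slack * step * grad
  return new_w, new_slack


w, s = jnp.array(3.0), jnp.array(1.0)
for _ in range(5):
  w, s = spssqrt_toy_step(w, s)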
def update(self, params, state, epoch, *args, **kwargs): """Perform one update of the algorithm. Args: params: pytree containing the parameters. state: named tuple containing the solver state. epoch: int. *args: additional positional arguments to be passed to ``fun``. **kwargs: additional keyword arguments to be passed to ``fun``. Returns: (params, state) """ if self.has_aux: (value, aux), grad = self._value_and_grad_fun(params, *args, **kwargs) else: value, grad = self._value_and_grad_fun(params, *args, **kwargs) aux = None # Currently experimenting with decreasing lambda slowly after many iterations. # The intuition behind this is that in the early iterations SPS (self.lmbda=1) # works well. But in later iterations the slack helpd stabilize. late_start = 10 if self.lmbda_schedule and epoch > late_start: lmbdat = self.lmbda/(jnp.log(jnp.log(epoch-late_start+1)+1)+1) else: lmbdat = self.lmbda ## Mathematical description on this step size: # step_size = (f_i(w^t) - (1-lmbda) s)_+) / (||grad||^2 + 1 - lmbda) step_size = jax.nn.relu(value - (1-lmbdat)*state.slack)/(tree_l2_norm( grad, squared=True) + 1 - lmbdat) newslack = (1 - lmbdat) * (state.slack + step_size) # new_params = tree_add_scalar_mul(params, -step_size, grad) if self.momentum == 0: new_params = tree_add_scalar_mul(params, -step_size, grad) new_velocity = None else: # new_v = momentum * v - step_size * grad # new_params = params + new_v new_velocity = tree_sub( tree_scalar_mul(self.momentum, state.velocity), tree_scalar_mul(step_size, grad)) new_params = tree_add(params, new_velocity) new_state = SPSDamState( iter_num=state.iter_num + 1, value=value, slack=newslack, velocity=new_velocity, aux=aux) return base.OptStep(params=new_params, state=new_state)
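# A minimal standalone sketch (illustration only): the damped SPS step size
# and slack update above, step = (f(w) - (1 - lmbda) s)_+ / (||g||^2 + 1 - lmbda)
# and s <- (1 - lmbda) * (s + step), together with the slow lambda decay from
# the comment above. The toy loss and all constants are assumptions.
import jax
import jax.numpy as jnp


def decayed_lmbda(lmbda, epoch, late_start=10):
  # Assumed schedule: lmbda_t = lmbda / (log(log(epoch - late_start + 1) + 1) + 1).
  return lmbda / (jnp.log(jnp.log(epoch - late_start + 1) + 1) + 1)


def spsdam_toy_step(w, slack, lmbda=0.5):
  f = lambda w: 0.5 * (w - 1.0) ** 2
  value, grad = jax.value_and_grad(f)(w)
  step = jax.nn.relu(value - (1 - lmbda) * slack) / (grad ** 2 + 1 - lmbda)
  new_slack = (1 - lmbda) * (slack + step)
  return w - step * grad, new_slack


w, s = jnp.array(3.0), jnp.array(1.0)
w, s = spsdam_toy_step(w, s, lmbda=decayed_lmbda(0.5, epoch=20))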
def update(self, params, state, epoch, *args, **kwargs): """Perform one update of the algorithm. Args: params: pytree containing the parameters. state: named tuple containing the solver state. epoch: int. *args: additional positional arguments to be passed to ``fun``. **kwargs: additional keyword arguments to be passed to ``fun``. Return type: base.OptStep Returns: (params, state) """ if self.has_aux: (value, aux), grad = self._value_and_grad_fun(params, *args, **kwargs) else: value, grad = self._value_and_grad_fun(params, *args, **kwargs) aux = None gradnorm = tree_l2_norm(grad, squared=True) step1 = jax.nn.relu(value - state.slack + self.delta * self.lmbda) / ( self.delta + gradnorm) spsstep = value / gradnorm step_size = jnp.minimum(step1, spsstep) newslack = jax.nn.relu(state.slack - self.lmbda * self.delta + self.delta * step1) # new_params = tree_add_scalar_mul(params, -step_size, grad) if self.momentum == 0: new_params = tree_add_scalar_mul(params, -step_size, grad) new_velocity = None else: # new_v = momentum * v - step_size * grad # new_params = params + new_v new_velocity = tree_sub(tree_scalar_mul(self.momentum, state.velocity), tree_scalar_mul(step_size, grad)) new_params = tree_add(params, new_velocity) new_state = SPSL1State( iter_num=state.iter_num + 1, value=value, slack=newslack, velocity=new_velocity, aux=aux) return base.OptStep(params=new_params, state=new_state)
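# A minimal standalone sketch (illustration only) of the L1-slack step above:
# a damped step step1 = (f(w) - s + delta * lmbda)_+ / (delta + ||g||^2),
# capped by the vanilla Polyak step f(w) / ||g||^2, with the slack updated
# through a relu. The toy loss, delta and lmbda are assumptions.
import jax
import jax.numpy as jnp


def spsl1_toy_step(w, slack, lmbda=0.1, delta=1.0):
  f = lambda w: 0.5 * (w - 1.0) ** 2
  value, grad = jax.value_and_grad(f)(w)
  gradnorm = grad ** 2
  step1 = jax.nn.relu(value - slack + delta * lmbda) / (delta + gradnorm)
  sps_step = value / gradnorm            # vanilla Polyak step, used as a cap
  step = jnp.minimum(step1, sps_step)
  new_slack = jax.nn.relu(slack - lmbda * delta + delta * step1)
  return w - step * grad, new_slack


w, s = spsl1_toy_step(jnp.array(3.0), jnp.array(1.0))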
def update(self, params, state, data, *args, **kwargs): """Performs one iteration of the optax solver. Args: params: pytree containing the parameters. state: named tuple containing the solver state. data: dict. *args: additional positional arguments to be passed to ``fun``. **kwargs: additional keyword arguments to be passed to ``fun``. Return type: base.OptStep Returns: (params, state) """ del args, kwargs # unused (value, aux), update = self._spsdiag_update(params, data) if self.momentum == 0: new_params = tree_add_scalar_mul(params, self.learning_rate, update) new_velocity = None else: new_velocity = tree_sub( tree_scalar_mul(self.momentum, state.velocity), tree_scalar_mul(self.learning_rate, update)) new_params = tree_add(params, new_velocity) new_params = tree_add_scalar_mul( params, self.learning_rate, update) aux['loss'] = jnp.mean(aux['loss']) aux['accuracy'] = jnp.mean(aux['accuracy']) if state.iter_num % 10 == 0: print('Number of iterations', state.iter_num, '. Objective function value: ', value) new_state = StochasticPolyakState( iter_num=state.iter_num+1, value=value, velocity=new_velocity, aux=aux) return base.OptStep(params=new_params, state=new_state)
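# A minimal standalone sketch (illustration only) of the momentum branch
# above: velocity <- momentum * velocity - lr * update, params <- params +
# velocity, written with jax.tree_util so it works on generic pytrees. The
# pytrees and constants here are assumptions for illustration.
import jax
import jax.numpy as jnp


def momentum_step(params, velocity, update, learning_rate=0.1, momentum=0.9):
  new_velocity = jax.tree_util.tree_map(
      lambda v, u: momentum * v - learning_rate * u, velocity, update)
  new_params = jax.tree_util.tree_map(
      lambda p, v: p + v, params, new_velocity)
  return new_params, new_velocity


params = {'w': jnp.zeros(3)}
velocity = {'w': jnp.zeros(3)}
update = {'w': jnp.ones(3)}
params, velocity = momentum_step(params, velocity, update)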
def update_pytrees_CG(self, params, state, epoch, data, *args, **kwargs):
  """Performs one iteration of the system Polyak solver by calling CG directly.

  Args:
    params: pytree containing the parameters.
    state: named tuple containing the solver state.
    epoch: int.
    data: a batch of data.
    *args: additional positional arguments to be passed to ``fun``.
    **kwargs: additional keyword arguments to be passed to ``fun``.

  Return type:
    base.OptStep

  Returns:
    (params, state)
  """
  del epoch, args, kwargs  # unused

  def losses(params):
    # Currently, self.fun returns the losses BEFORE the mean reduction.
    return self.fun(params, data)[0]

  # TODO(rmgower): avoid recomputing the auxiliary output (metrics)
  aux = self.fun(params, data)[1]

  # Get the Jacobian-transpose operator.
  Jt = jax.vjp(losses, params)[1]

  @jax.jit
  def matvec(u):
    """Matrix-vector product.

    Args:
      u: a vector of length batch_size

    Returns:
      (J J^T + delta * I) u = J(J^T(u)) + delta * u
    """
    # Important: this is slow.
    Jtu = Jt(u)  # Jacobian-transpose vector product
    # Jacobian-vector product
    JJtu = jax.jvp(losses, (params,), (Jtu[0],))[1]
    deltau = self.delta * u
    return JJtu + deltau

  # Solve the small linear system (J J^T + delta * I) x = -loss.
  # Warning: this is the bottleneck cost.
  rhs = -losses(params)
  cg_sol = linear_solve.solve_cg(matvec, rhs, init=None, maxiter=20)

  # Build the final solution w = w - J^T (J J^T + delta * I)^{-1} loss.
  Jtsol = Jt(cg_sol)[0]
  new_params = tree_util.tree_add(params, Jtsol)

  if state.iter_num % 10 == 0:
    print('Number of iterations', state.iter_num,
          '. Objective function value: ', jnp.mean(-rhs))

  new_state = SystemStochasticPolyakState(iter_num=state.iter_num + 1,
                                          aux=aux)
  return base.OptStep(params=new_params, state=new_state)
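# A minimal standalone sketch (illustration only) of the matrix-free
# (J J^T + delta * I) u product built from jax.vjp / jax.jvp above, solved
# here with jax.scipy.sparse.linalg.cg rather than jaxopt's
# linear_solve.solve_cg. The toy residual function and delta are assumptions.
import jax
import jax.numpy as jnp


def residuals(w):
  A = jnp.arange(6.0).reshape(3, 2)
  b = jnp.ones(3)
  return A @ w - b  # one "loss" per data point


w = jnp.zeros(2)
delta = 1e-3
_, vjp_fn = jax.vjp(residuals, w)  # u -> J^T u


def matvec(u):
  jtu = vjp_fn(u)[0]                          # J^T u
  jjtu = jax.jvp(residuals, (w,), (jtu,))[1]  # J (J^T u)
  return jjtu + delta * u


rhs = -residuals(w)
sol, _ = jax.scipy.sparse.linalg.cg(matvec, rhs, maxiter=20)
new_w = w + vjp_fn(sol)[0]  # w - J^T (J J^T + delta I)^{-1} loss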