def update(self, params, state, epoch, *args, **kwargs):
  """Performs one update of the algorithm.

  Args:
    params: pytree containing the parameters.
    state: named tuple containing the solver state.
    epoch: current epoch number.
    *args: additional positional arguments to be passed to ``fun``.
    **kwargs: additional keyword arguments to be passed to ``fun``.

  Return type:
    base.OptStep

  Returns:
    (params, state)
  """
  del epoch  # unused
  if self.lmbda == 1:
    raise ValueError(
        'lmbda=1 was passed to the SPSsqrt solver. This solver does not '
        'work with lmbda=1 because then the parameters are never updated!')
  if self.has_aux:
    (value, aux), grad = self._value_and_grad_fun(params, *args, **kwargs)
  else:
    value, grad = self._value_and_grad_fun(params, *args, **kwargs)
    aux = None

  # If the slack hits zero, reset it to the current value.
  # This stops the method from halting.
  if state.slack == 0.0:
    state = state._replace(slack=value)

  # The mathematical expression of this update is:
  #   step = (value - (1 - lmbda/2) * sqrt(s))_+ / (4 * s * ||grad||^2 + 1 - lmbda)
  #   w = w - 4 * step * s * grad
  #   s = (1 - lmbda) * sqrt(s) * (sqrt(s) + step)
  step_size = jax.nn.relu(
      value - (1 - self.lmbda / 2) * jnp.sqrt(state.slack)) / (
          4 * state.slack * tree_l2_norm(grad, squared=True)
          + 1 - self.lmbda)
  newslack = (1 - self.lmbda) * jnp.sqrt(state.slack) * (
      jnp.sqrt(state.slack) + step_size)
  step_size = 4 * state.slack * step_size

  if self.momentum == 0:
    new_params = tree_add_scalar_mul(params, -step_size, grad)
    new_velocity = None
  else:
    # new_v = momentum * v - step_size * grad
    # new_params = params + new_v
    new_velocity = tree_sub(
        tree_scalar_mul(self.momentum, state.velocity),
        tree_scalar_mul(step_size, grad))
    new_params = tree_add(params, new_velocity)

  new_state = SPSsqrtState(
      iter_num=state.iter_num + 1,
      value=value,
      slack=newslack,
      velocity=new_velocity,
      aux=aux)
  return base.OptStep(params=new_params, state=new_state)
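# Hypothetical standalone sketch (not part of the solver): the SPSsqrt step
# arithmetic above, isolated for scalar inputs as a numerical sanity check.
# The helper name `_spssqrt_step_sketch` and its arguments are assumptions
# for illustration; the formulas mirror the commented update rule.
import jax
import jax.numpy as jnp

def _spssqrt_step_sketch(value, slack, grad_sq, lmbda):
  """Returns (step_size, new_slack) for scalars, assuming slack > 0 and lmbda < 1."""
  raw = jax.nn.relu(value - (1 - lmbda / 2) * jnp.sqrt(slack)) / (
      4 * slack * grad_sq + 1 - lmbda)
  new_slack = (1 - lmbda) * jnp.sqrt(slack) * (jnp.sqrt(slack) + raw)
  return 4 * slack * raw, new_slack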
def update(self, params, state, epoch, *args, **kwargs):
  """Performs one update of the algorithm.

  Args:
    params: pytree containing the parameters.
    state: named tuple containing the solver state.
    epoch: current epoch number.
    *args: additional positional arguments to be passed to ``fun``.
    **kwargs: additional keyword arguments to be passed to ``fun``.

  Return type:
    base.OptStep

  Returns:
    (params, state)
  """
  if self.has_aux:
    (value, aux), grad = self._value_and_grad_fun(params, *args, **kwargs)
  else:
    value, grad = self._value_and_grad_fun(params, *args, **kwargs)
    aux = None

  # Experimental: decrease lmbda slowly after `late_start` epochs. The
  # intuition is that in early iterations plain SPS (self.lmbda=1) works
  # well, while in later iterations the slack helps stabilize the iterates.
  late_start = 10
  if self.lmbda_schedule and epoch > late_start:
    lmbdat = self.lmbda / (jnp.log(jnp.log(epoch - late_start + 1) + 1) + 1)
  else:
    lmbdat = self.lmbda

  # Mathematical description of this step size:
  #   step_size = (f_i(w^t) - (1 - lmbda) * s)_+ / (||grad||^2 + 1 - lmbda)
  step_size = jax.nn.relu(value - (1 - lmbdat) * state.slack) / (
      tree_l2_norm(grad, squared=True) + 1 - lmbdat)
  newslack = (1 - lmbdat) * (state.slack + step_size)

  if self.momentum == 0:
    new_params = tree_add_scalar_mul(params, -step_size, grad)
    new_velocity = None
  else:
    # new_v = momentum * v - step_size * grad
    # new_params = params + new_v
    new_velocity = tree_sub(
        tree_scalar_mul(self.momentum, state.velocity),
        tree_scalar_mul(step_size, grad))
    new_params = tree_add(params, new_velocity)

  new_state = SPSDamState(
      iter_num=state.iter_num + 1,
      value=value,
      slack=newslack,
      velocity=new_velocity,
      aux=aux)
  return base.OptStep(params=new_params, state=new_state)
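# Hypothetical illustration (not used by the solver): the double-log schedule
# above, isolated as a function. The helper name and example values are
# assumptions. Because of the double logarithm the decay is extremely slow,
# e.g. for lmbda=0.5 and late_start=10: epoch 11 gives ~0.33, epoch 100 ~0.18.
import jax.numpy as jnp

def _lmbda_schedule_sketch(lmbda, epoch, late_start=10):
  """Mirrors the lmbda schedule in `update` for epoch > late_start."""
  return lmbda / (jnp.log(jnp.log(epoch - late_start + 1) + 1) + 1)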
def update(self, params, state, epoch, *args, **kwargs):
  """Performs one update of the algorithm.

  Args:
    params: pytree containing the parameters.
    state: named tuple containing the solver state.
    epoch: current epoch number.
    *args: additional positional arguments to be passed to ``fun``.
    **kwargs: additional keyword arguments to be passed to ``fun``.

  Return type:
    base.OptStep

  Returns:
    (params, state)
  """
  del epoch  # unused
  if self.has_aux:
    (value, aux), grad = self._value_and_grad_fun(params, *args, **kwargs)
  else:
    value, grad = self._value_and_grad_fun(params, *args, **kwargs)
    aux = None

  gradnorm = tree_l2_norm(grad, squared=True)
  # Slack-damped candidate step:
  #   step1 = (f_i(w^t) - s + delta * lmbda)_+ / (delta + ||grad||^2)
  step1 = jax.nn.relu(value - state.slack + self.delta * self.lmbda) / (
      self.delta + gradnorm)
  # Classical SPS candidate step: f_i(w^t) / ||grad||^2.
  spsstep = value / gradnorm
  # Use the smaller of the two candidate steps.
  step_size = jnp.minimum(step1, spsstep)
  newslack = jax.nn.relu(
      state.slack - self.lmbda * self.delta + self.delta * step1)

  if self.momentum == 0:
    new_params = tree_add_scalar_mul(params, -step_size, grad)
    new_velocity = None
  else:
    # new_v = momentum * v - step_size * grad
    # new_params = params + new_v
    new_velocity = tree_sub(
        tree_scalar_mul(self.momentum, state.velocity),
        tree_scalar_mul(step_size, grad))
    new_params = tree_add(params, new_velocity)

  new_state = SPSL1State(
      iter_num=state.iter_num + 1,
      value=value,
      slack=newslack,
      velocity=new_velocity,
      aux=aux)
  return base.OptStep(params=new_params, state=new_state)
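# Hypothetical scalar sketch (assumed helper name and inputs, not part of the
# solver): the SPSL1 step-size selection above, which takes the minimum of the
# slack-damped candidate step and the classical SPS step f_i(w)/||grad||^2.
import jax
import jax.numpy as jnp

def _spsl1_step_sketch(value, slack, grad_sq, lmbda, delta):
  """Returns (step_size, new_slack) for scalar inputs, mirroring `update`."""
  step1 = jax.nn.relu(value - slack + delta * lmbda) / (delta + grad_sq)
  spsstep = value / grad_sq
  new_slack = jax.nn.relu(slack - lmbda * delta + delta * step1)
  return jnp.minimum(step1, spsstep), new_slack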
def update(self, params, state, data, *args, **kwargs):
  """Performs one iteration of the solver.

  Args:
    params: pytree containing the parameters.
    state: named tuple containing the solver state.
    data: dict containing the current batch of data.
    *args: additional positional arguments to be passed to ``fun``.
    **kwargs: additional keyword arguments to be passed to ``fun``.

  Return type:
    base.OptStep

  Returns:
    (params, state)
  """
  del args, kwargs  # unused
  (value, aux), update = self._spsdiag_update(params, data)

  if self.momentum == 0:
    new_params = tree_add_scalar_mul(params, self.learning_rate, update)
    new_velocity = None
  else:
    # new_v = momentum * v - learning_rate * update
    # new_params = params + new_v
    new_velocity = tree_sub(
        tree_scalar_mul(self.momentum, state.velocity),
        tree_scalar_mul(self.learning_rate, update))
    new_params = tree_add(params, new_velocity)

  aux['loss'] = jnp.mean(aux['loss'])
  aux['accuracy'] = jnp.mean(aux['accuracy'])
  if state.iter_num % 10 == 0:
    print('Number of iterations:', state.iter_num,
          'Objective function value:', value)

  new_state = StochasticPolyakState(
      iter_num=state.iter_num + 1,
      value=value,
      velocity=new_velocity,
      aux=aux)
  return base.OptStep(params=new_params, state=new_state)
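# Hypothetical usage sketch (the names `solver`, `params`, `state`, and
# `batches` are assumptions for illustration): `update` returns a
# base.OptStep, i.e. a (params, state) pair, so one epoch of training simply
# threads that pair through a loop over batches.
def run_epoch_sketch(solver, params, state, batches):
  """Calls `solver.update` once per batch and returns the final pair."""
  for data in batches:
    params, state = solver.update(params, state, data)
  return params, state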