def update(self, params, state, epoch, *args, **kwargs):
  """Performs one update of the algorithm.

  Args:
    params: pytree containing the parameters.
    state: named tuple containing the solver state.
    epoch: epoch number.
    *args: additional positional arguments to be passed to ``fun``.
    **kwargs: additional keyword arguments to be passed to ``fun``.

  Return type:
    base.OptStep

  Returns:
    (params, state)
  """
  del epoch  # unused

  if self.lmbda == 1:
    raise ValueError(
        'lmbda=1 was passed to the SPSsqrt solver. This solver does not '
        'work with lmbda=1 because the parameters would never be updated.')

  if self.has_aux:
    (value, aux), grad = self._value_and_grad_fun(params, *args, **kwargs)
  else:
    value, grad = self._value_and_grad_fun(params, *args, **kwargs)
    aux = None

  # If the slack hits zero, reset it to the current loss value.
  # This prevents the method from halting.
  if state.slack == 0.0:
    state = state._replace(slack=value)

  # The mathematical expression of this update is:
  #   step = (f(w) - (1 - lmbda/2) * sqrt(s))_+ / (4 s ||grad||^2 + 1 - lmbda)
  #   w = w - 4 * step * s * grad
  #   s = (1 - lmbda) * sqrt(s) * (sqrt(s) + step)
  step_size = jax.nn.relu(
      value - (1 - self.lmbda / 2) * jnp.sqrt(state.slack)) / (
          4 * state.slack * tree_l2_norm(grad, squared=True) + 1 - self.lmbda)
  newslack = (1 - self.lmbda) * jnp.sqrt(
      state.slack) * (jnp.sqrt(state.slack) + step_size)
  step_size = 4 * state.slack * step_size

  if self.momentum == 0:
    new_params = tree_add_scalar_mul(params, -step_size, grad)
    new_velocity = None
  else:
    # new_v = momentum * v - step_size * grad
    # new_params = params + new_v
    new_velocity = tree_sub(
        tree_scalar_mul(self.momentum, state.velocity),
        tree_scalar_mul(step_size, grad))
    new_params = tree_add(params, new_velocity)

  new_state = SPSsqrtState(
      iter_num=state.iter_num + 1,
      value=value,
      slack=newslack,
      velocity=new_velocity,
      aux=aux)
  return base.OptStep(params=new_params, state=new_state)
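
# A minimal standalone sketch (not part of the solver) of the SPSsqrt step
# above on a scalar quadratic f(w) = 0.5 * w**2, assuming momentum = 0. The
# names `lmbda`, `w` and `s` are hypothetical stand-ins for self.lmbda,
# params and state.slack.
def _demo_spssqrt_step():
  import jax
  import jax.numpy as jnp
  lmbda, w, s = 0.5, jnp.array(2.0), jnp.array(1.0)
  value, grad = jax.value_and_grad(lambda w: 0.5 * w**2)(w)
  step = jax.nn.relu(value - (1 - lmbda / 2) * jnp.sqrt(s)) / (
      4 * s * grad**2 + 1 - lmbda)
  new_slack = (1 - lmbda) * jnp.sqrt(s) * (jnp.sqrt(s) + step)
  new_w = w - 4 * s * step * grad
  # Repeating this step shrinks both the loss and the slack toward zero.
  return new_w, new_slack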
def update(self, params, state, epoch, *args, **kwargs):
  """Performs one update of the algorithm.

  Args:
    params: pytree containing the parameters.
    state: named tuple containing the solver state.
    epoch: epoch number.
    *args: additional positional arguments to be passed to ``fun``.
    **kwargs: additional keyword arguments to be passed to ``fun``.

  Returns:
    (params, state)
  """
  if self.has_aux:
    (value, aux), grad = self._value_and_grad_fun(params, *args, **kwargs)
  else:
    value, grad = self._value_and_grad_fun(params, *args, **kwargs)
    aux = None

  # Experimental: decrease lambda slowly after many iterations. The
  # intuition is that in the early iterations SPS (self.lmbda=1) works
  # well, but in later iterations the slack helps stabilize the method.
  late_start = 10
  if self.lmbda_schedule and epoch > late_start:
    lmbdat = self.lmbda / (jnp.log(jnp.log(epoch - late_start + 1) + 1) + 1)
  else:
    lmbdat = self.lmbda

  # Mathematical description of this step size:
  #   step_size = (f_i(w^t) - (1 - lmbda) * s)_+ / (||grad||^2 + 1 - lmbda)
  step_size = jax.nn.relu(value - (1 - lmbdat) * state.slack) / (
      tree_l2_norm(grad, squared=True) + 1 - lmbdat)
  newslack = (1 - lmbdat) * (state.slack + step_size)

  if self.momentum == 0:
    new_params = tree_add_scalar_mul(params, -step_size, grad)
    new_velocity = None
  else:
    # new_v = momentum * v - step_size * grad
    # new_params = params + new_v
    new_velocity = tree_sub(
        tree_scalar_mul(self.momentum, state.velocity),
        tree_scalar_mul(step_size, grad))
    new_params = tree_add(params, new_velocity)

  new_state = SPSDamState(
      iter_num=state.iter_num + 1,
      value=value,
      slack=newslack,
      velocity=new_velocity,
      aux=aux)
  return base.OptStep(params=new_params, state=new_state)
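
# A hedged scalar sketch (illustration only) of the damped step size used
# above, with momentum = 0. `lmbda`, `w` and `s` are hypothetical stand-ins
# for self.lmbda, params and state.slack. Note that the log-log schedule
# decays lmbda very slowly: at epoch = 30 with late_start = 10 the factor
# is 1 / (log(log(21) + 1) + 1), roughly 0.42.
def _demo_spsdam_step():
  import jax
  import jax.numpy as jnp
  lmbda, w, s = 0.5, jnp.array(2.0), jnp.array(0.0)
  value, grad = jax.value_and_grad(lambda w: 0.5 * w**2)(w)
  step = jax.nn.relu(value - (1 - lmbda) * s) / (grad**2 + 1 - lmbda)
  new_slack = (1 - lmbda) * (s + step)
  new_w = w - step * grad
  return new_w, new_slack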
def update_arrays_CG(self, params, state, data, *args, **kwargs):
  """Performs the update using CG."""
  del kwargs  # unused
  batch_size = data['label'].shape[0]
  _, unravel_pytree = flatten_util.ravel_pytree(params)

  @jax.jit
  def loss_sample(image, label):
    tmp_kwargs = {'data': {'image': image, 'label': label}}
    # Compute the loss and gradient on a single image/label pair.
    if self.has_aux:
      # We only store the last value of aux.
      (value_i, aux), grad_i = self._value_and_grad_fun(
          params, *args, **tmp_kwargs)
    else:
      value_i, grad_i = self._value_and_grad_fun(params, *args, **tmp_kwargs)
      aux = None
    grad_i_flatten, _ = flatten_util.ravel_pytree(grad_i)
    return value_i, aux, grad_i_flatten

  @jax.jit
  def matvec_array(u):
    """Computes the product (J J^T + delta * I) u."""
    return grads @ (u @ grads) + self.delta * u

  # We add a new axis on data and labels so they have the correct shape
  # after vectorization by vmap, which removes the batch dimension.
  # Important: this vmap is the bottleneck cost of this update!
  expand_data = jnp.expand_dims(data['image'], axis=1)
  expand_labels = jnp.expand_dims(data['label'], axis=1)
  values, aux, grads = jax.vmap(loss_sample, in_axes=(0, 0))(expand_data,
                                                             expand_labels)

  # Solve v = (J J^T + delta * I)^{-1} loss.
  v = linear_solve.solve_cg(matvec_array, values, init=None, maxiter=20)
  # Build the final update v = J^T (J J^T + delta * I)^{-1} loss.
  v = v @ grads
  v_tree = unravel_pytree(v)
  new_params = tree_util.tree_add_scalar_mul(params, -1.0, v_tree)

  value = jnp.mean(values)
  if state.iter_num % 10 == 0:
    print('Number of iterations', state.iter_num,
          '. Objective function value: ', value)
  new_state = SystemStochasticPolyakState(
      iter_num=state.iter_num + 1, aux=aux)
  return base.OptStep(params=new_params, state=new_state)
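
# A small self-contained check, on made-up test data, of the linear algebra
# used above: solve (J J^T + delta * I) u = r with CG, then form the update
# J^T u. `J`, `r` and `delta` here are arbitrary illustrative values, and
# jaxopt's linear_solve module (already used above) is assumed.
def _demo_system_cg():
  import jax
  import jax.numpy as jnp
  from jaxopt import linear_solve
  J = jax.random.normal(jax.random.PRNGKey(0), (4, 10))  # batch x params
  r = jnp.ones(4)
  delta = 1e-3
  matvec = lambda u: J @ (u @ J) + delta * u
  u = linear_solve.solve_cg(matvec, r, init=None, maxiter=20)
  v = u @ J  # flattened update J^T u
  # Residual of the regularized system should be near zero.
  return jnp.linalg.norm(J @ v + delta * u - r)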
def update(self, params, state, data, *args, **kwargs):
  """Performs one iteration of the optax solver.

  Args:
    params: pytree containing the parameters.
    state: named tuple containing the solver state.
    data: dict of batch data.
    *args: additional positional arguments to be passed to ``fun``.
    **kwargs: additional keyword arguments to be passed to ``fun``.

  Return type:
    base.OptStep

  Returns:
    (params, state)
  """
  del args, kwargs  # unused
  (value, aux), update = self._spsdiag_update(params, data)

  if self.momentum == 0:
    new_params = tree_add_scalar_mul(params, self.learning_rate, update)
    new_velocity = None
  else:
    new_velocity = tree_sub(
        tree_scalar_mul(self.momentum, state.velocity),
        tree_scalar_mul(self.learning_rate, update))
    new_params = tree_add(params, new_velocity)

  aux['loss'] = jnp.mean(aux['loss'])
  aux['accuracy'] = jnp.mean(aux['accuracy'])
  if state.iter_num % 10 == 0:
    print('Number of iterations', state.iter_num,
          '. Objective function value: ', value)
  new_state = StochasticPolyakState(
      iter_num=state.iter_num + 1,
      value=value,
      velocity=new_velocity,
      aux=aux)
  return base.OptStep(params=new_params, state=new_state)
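
# A toy unrolling (illustration only, made-up scalar values) of the generic
# momentum recursion used above: new_v = momentum * v - learning_rate * g
# followed by w = w + new_v.
def _demo_momentum():
  import jax.numpy as jnp
  momentum, lr = 0.9, 0.1
  w, v = jnp.array(1.0), jnp.array(0.0)
  for _ in range(3):
    g = w  # gradient of f(w) = 0.5 * w**2
    v = momentum * v - lr * g
    w = w + v
  # w moves toward the minimizer while v accumulates past gradients.
  return w, v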
def update_arrays_lstsq(self, params, state, data, *args, **kwargs):
  """Performs the update using a least-squares solver.

  This version uses the least-squares solver jnp.linalg.lstsq, which has
  two problems:
  1. It is too slow because it computes a full SVD (overkill) to solve
     the systems.
  2. It has no support for regularization.
  """
  del kwargs  # unused
  batch_size = data['label'].shape[0]
  _, unravel_pytree = flatten_util.ravel_pytree(params)

  @jax.jit
  def loss_sample(image, label):
    tmp_kwargs = {'data': {'image': image, 'label': label}}
    # Compute the loss and gradient on a single image/label pair.
    if self.has_aux:
      # We only store the last value of aux.
      (value_i, aux), grad_i = self._value_and_grad_fun(
          params, *args, **tmp_kwargs)
    else:
      value_i, grad_i = self._value_and_grad_fun(params, *args, **tmp_kwargs)
      aux = None
    grad_i_flatten, _ = flatten_util.ravel_pytree(grad_i)
    return value_i, aux, grad_i_flatten

  # We add a new axis on data and labels so they have the correct shape
  # after vectorization by vmap, which removes the batch dimension.
  expand_data = jnp.expand_dims(data['image'], axis=1)
  expand_labels = jnp.expand_dims(data['label'], axis=1)
  values, aux, grads = jax.vmap(loss_sample, in_axes=(0, 0))(expand_data,
                                                             expand_labels)

  # This is too slow; a faster implementation is needed.
  v = jnp.linalg.lstsq(grads, values)[0]
  v_tree = unravel_pytree(v)
  new_params = tree_util.tree_add_scalar_mul(params, -1.0, v_tree)

  value = jnp.mean(values)
  if state.iter_num % 10 == 0:
    print('Number of iterations', state.iter_num,
          '. Objective function value: ', value)
  new_state = SystemStochasticPolyakState(
      iter_num=state.iter_num + 1, aux=aux)
  return base.OptStep(params=new_params, state=new_state)
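
# Minimal illustration, on made-up data, of why this lstsq variant is the
# unregularized limit of the CG variant: for a full-row-rank J, lstsq
# returns the minimum-norm solution J^+ r = J^T (J J^T)^{-1} r.
def _demo_lstsq_equivalence():
  import jax
  import jax.numpy as jnp
  J = jax.random.normal(jax.random.PRNGKey(1), (4, 10))
  r = jnp.ones(4)
  v_lstsq = jnp.linalg.lstsq(J, r)[0]
  v_normal = J.T @ jnp.linalg.solve(J @ J.T, r)
  return jnp.linalg.norm(v_lstsq - v_normal)  # expected to be near zero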
def update_jacrev_arrays_CG(self, params, state, data, *args, **kwargs):
  """Performs the update using jacrev and CG.

  This is currently the fastest implementation.
  """
  del args, kwargs  # unused
  batch_size = data['label'].shape[0]
  _, unravel_pytree = flatten_util.ravel_pytree(params)

  @jax.jit
  def losses(params):
    # self.fun returns the per-sample losses BEFORE the mean reduction.
    return self.fun(params, data)[0]

  # TODO(rmgower): avoid recomputing the auxiliary output (metrics).
  aux = self.fun(params, data)[1]

  @jax.jit
  def matvec_array(u):
    # Computes the product (J J^T + delta * I) u.
    return grads @ (u @ grads) + self.delta * u

  def jacobian_builder(losses, params):
    grads_tree = jax.jacrev(losses)(params)
    grads, _ = flatten_util.ravel_pytree(grads_tree)
    grads = jnp.reshape(grads,
                        (batch_size, int(grads.shape[0] / batch_size)))
    return grads

  # Important: this is the bottleneck cost of this update!
  grads = jacobian_builder(losses, params)
  values = losses(params)

  # Solve v = (J J^T + delta * I)^{-1} loss.
  v = linear_solve.solve_cg(matvec_array, values, init=None, maxiter=10)
  # Build the final update v = J^T (J J^T + delta * I)^{-1} loss.
  v = v @ grads
  v_tree = unravel_pytree(v)
  new_params = tree_util.tree_add_scalar_mul(params, -1.0, v_tree)

  value = jnp.mean(values)
  if state.iter_num % 10 == 0:
    print('Number of iterations', state.iter_num,
          '. Objective function value: ', value)
  new_state = SystemStochasticPolyakState(
      iter_num=state.iter_num + 1, aux=aux)
  return base.OptStep(params=new_params, state=new_state)
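
# Self-contained sketch, on a made-up linear model, of the per-sample
# Jacobian construction used above: jax.jacrev applied to a function that
# returns per-sample losses yields the matrix J with one row per sample.
def _demo_jacrev_jacobian():
  import jax
  import jax.numpy as jnp
  X = jax.random.normal(jax.random.PRNGKey(0), (4, 3))  # 4 samples
  y = jnp.ones(4)
  w = jnp.zeros(3)
  def per_sample_losses(w):
    return 0.5 * (X @ w - y) ** 2  # shape (4,), before the mean reduction
  J = jax.jacrev(per_sample_losses)(w)  # shape (4, 3)
  return J.shape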
def update(self, params, state, epoch, *args, **kwargs):
  """Performs one update of the algorithm.

  Args:
    params: pytree containing the parameters.
    state: named tuple containing the solver state.
    epoch: epoch number.
    *args: additional positional arguments to be passed to ``fun``.
    **kwargs: additional keyword arguments to be passed to ``fun``.

  Return type:
    base.OptStep

  Returns:
    (params, state)
  """
  del epoch  # unused

  if self.has_aux:
    (value, aux), grad = self._value_and_grad_fun(params, *args, **kwargs)
  else:
    value, grad = self._value_and_grad_fun(params, *args, **kwargs)
    aux = None

  gradnorm = tree_l2_norm(grad, squared=True)
  step1 = jax.nn.relu(value - state.slack + self.delta * self.lmbda) / (
      self.delta + gradnorm)
  spsstep = value / gradnorm
  step_size = jnp.minimum(step1, spsstep)
  newslack = jax.nn.relu(
      state.slack - self.lmbda * self.delta + self.delta * step1)

  if self.momentum == 0:
    new_params = tree_add_scalar_mul(params, -step_size, grad)
    new_velocity = None
  else:
    # new_v = momentum * v - step_size * grad
    # new_params = params + new_v
    new_velocity = tree_sub(
        tree_scalar_mul(self.momentum, state.velocity),
        tree_scalar_mul(step_size, grad))
    new_params = tree_add(params, new_velocity)

  new_state = SPSL1State(
      iter_num=state.iter_num + 1,
      value=value,
      slack=newslack,
      velocity=new_velocity,
      aux=aux)
  return base.OptStep(params=new_params, state=new_state)
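
# Hedged scalar sketch of the SPSL1 step above with momentum = 0. The names
# `lmbda`, `delta`, `w` and `s` are stand-ins for the solver attributes and
# state; the values are made up.
def _demo_spsl1_step():
  import jax
  import jax.numpy as jnp
  lmbda, delta, w, s = 0.1, 1.0, jnp.array(2.0), jnp.array(0.5)
  value, grad = jax.value_and_grad(lambda w: 0.5 * w**2)(w)
  gradnorm = grad**2
  step1 = jax.nn.relu(value - s + delta * lmbda) / (delta + gradnorm)
  step = jnp.minimum(step1, value / gradnorm)  # capped by the plain SPS step
  new_slack = jax.nn.relu(s - lmbda * delta + delta * step1)
  new_w = w - step * grad
  return new_w, new_slack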
def projection_halfspace(x, a, b):
  r"""Projection onto a halfspace defined by a pytree and a scalar.

  The output is: ``argmin_{y, dot(a, y) <= b} ||y - x||``.

  Args:
    x: pytree to project.
    a: pytree defining the halfspace normal.
    b: scalar offset of the halfspace.

  Returns:
    y: output pytree (same structure as ``x``)
  """
  scale = jax.nn.relu(tree_util.tree_vdot(a, x) - b) / tree_util.tree_vdot(a, a)
  return tree_util.tree_add_scalar_mul(x, -scale, a)
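
# Example usage (arrays are valid pytrees): project x = [2, 2] onto the
# halfspace {y : y[0] <= 1}. The result is [1, 2]; a feasible x would be
# returned unchanged because the relu zeroes the correction.
def _demo_projection_halfspace():
  import jax.numpy as jnp
  x = jnp.array([2.0, 2.0])
  a = jnp.array([1.0, 0.0])
  return projection_halfspace(x, a, b=1.0)  # -> [1.0, 2.0]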
def projection_hyperplane(a, b, x=None):
  r"""Projection onto a hyperplane defined by a pytree and a scalar.

  The output is: ``argmin_{y, dot(a, y) = b} ||y - x||``,
  which is equivalent to ``y = x - ((<a, x> - b) / <a, a>) * a``.

  Args:
    a: pytree defining the hyperplane normal.
    b: scalar offset of the hyperplane.
    x: pytree to project. If ``None``, the projection of the origin is
      returned.

  Returns:
    y: output pytree (same structure as ``a``)
  """
  if x is None:
    scale = b / tree_util.tree_vdot(a, a)
    return tree_util.tree_scalar_mul(scale, a)
  else:
    scale = (tree_util.tree_vdot(a, x) - b) / tree_util.tree_vdot(a, a)
    return tree_util.tree_add_scalar_mul(x, -scale, a)
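
# Example usage: with x omitted, the origin is projected onto the
# hyperplane {y : <a, y> = b}; otherwise x is projected.
def _demo_projection_hyperplane():
  import jax.numpy as jnp
  a = jnp.array([1.0, 1.0])
  y0 = projection_hyperplane(a, 2.0)                          # -> [1.0, 1.0]
  y = projection_hyperplane(a, 2.0, x=jnp.array([3.0, 1.0]))  # -> [2.0, 0.0]
  # In both cases dot(a, y) == b.
  return y0, y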