def _update_sketched(self, g, w, m, v1, v2, opt_params): """Update for higher-rank parameters.""" learning_rate = opt_params['learning_rate'] momentum = opt_params['momentum'] beta2 = opt_params['second_moment_averaging'] weight_decay = opt_params['weight_decay'] shape = w.shape rank = len(shape) reshaped_accumulators = [ jnp.reshape(v1[i], self._expanded_shape(shape, i)) for i in range(rank) ] acc = self._minimum(reshaped_accumulators) is_beta2_1 = (beta2 == 1).astype(g.dtype) one_minus_beta2_except1 = is_beta2_1 + (1.0 - beta2) * (1.0 - is_beta2_1) acc = beta2 * acc + one_minus_beta2_except1 * g * g preconditioner = jnp.where(acc > 0.0, 1.0 / (jnp.sqrt(acc) + 1e-16), jnp.zeros_like(acc)) pg = g * preconditioner if self._graft: v2_acc = self._minimum([ jnp.reshape(v2[i], self._expanded_shape(shape, i)) for i in range(rank) ]) v2_acc = v2_acc + g * g preconditioner_graft = jnp.where(v2_acc > 0.0, 1.0 / (jnp.sqrt(v2_acc) + 1e-16), jnp.zeros_like(v2_acc)) pg_graft = preconditioner_graft * g pg_norm = jnp.linalg.norm(pg) pg_graft_norm = jnp.linalg.norm(pg_graft) pg = pg * (pg_graft_norm / (pg_norm + 1e-16)) pg = pg + w * weight_decay if self._has_momentum: m, update = self._momentum_update(pg, m, momentum) else: update = pg w = w - (learning_rate * update).astype(w.dtype) for i in range(len(v1)): axes = list(range(int(i))) + list(range(int(i) + 1, rank)) dim_accumulator = jnp.amax(acc, axis=axes) v1[i] = dim_accumulator if self._graft: for i in range(len(v2)): axes = list(range(int(i))) + list(range(int(i) + 1, rank)) dim_accumulator = jnp.amax(v2_acc, axis=axes) v2[i] = dim_accumulator return w, (m, v1, v2)
def representation_mask(mask):
    """Reduce trailing mask axes and repeat the result per representation slot.

    Every axis of ``mask`` from index 2 onward is collapsed with a max, a new
    last axis is appended, and the values are repeated
    ``serializer.representation_length`` times along axis 2.

    NOTE(review): the original inline comments describe the reduced mask as
    (batch_size, 4) and the result as (batch_size, 4, representation_length);
    neither size is checked here.
    """
    trailing_axes = tuple(range(2, mask.ndim))
    collapsed = jnp.amax(mask, axis=trailing_axes)
    # Same as collapsed[..., jnp.newaxis]: add a size-1 axis at the end.
    with_slot_axis = jnp.expand_dims(collapsed, axis=-1)
    return jnp.repeat(
        with_slot_axis,
        repeats=serializer.representation_length,
        axis=2,
    )
def _update_sketched(self, grads, weights, m, v, opt_params): """Update for higher-rank parameters.""" learning_rate = opt_params['learning_rate'] momentum = opt_params['momentum'] shape = weights.shape rank = len(shape) reshaped_accumulators = [jnp.reshape(v[i], self._expanded_shape(shape, i)) for i in range(rank)] current_accumulator = self._minimum(reshaped_accumulators) current_accumulator += grads * grads accumulator_inv_sqrt = jnp.where(current_accumulator > 0.0, 1.0 / jnp.sqrt(current_accumulator), jnp.zeros_like(current_accumulator)) preconditioned_gradient = grads * accumulator_inv_sqrt m = (1.0 - momentum) * preconditioned_gradient + momentum * m weights = weights - (learning_rate * m).astype(weights.dtype) for i in range(len(v)): axes = list(range(int(i))) + list(range(int(i) + 1, rank)) dim_accumulator = jnp.amax(current_accumulator, axis=axes) v[i] = dim_accumulator return weights, (m, v)
def representation_mask(mask):
    """Reduce trailing mask axes and broadcast over the representation length.

    Every axis of ``mask`` from index 2 onward is collapsed with a max. The
    resulting two-axis mask gets a new last axis and is broadcast so that axis
    has size ``serializer.representation_length``.
    """
    reduced = jnp.amax(mask, axis=tuple(range(2, mask.ndim)))
    # Size-1 trailing axis that broadcast_to stretches to the target length.
    expanded = reduced[:, :, None]
    target_shape = (*reduced.shape, serializer.representation_length)
    return jnp.broadcast_to(expanded, target_shape)