def _update_sketched(self, grads, params, m, v, opt_params):
  """Update for higher-rank parameters."""
  (learning_rate, momentum) = opt_params
  shape = params.shape
  rank = len(shape)
  # Broadcast each per-dimension accumulator v[i] to the full rank, with
  # size 1 in every axis except i.
  reshaped_accumulators = [np.reshape(v[i], self._expanded_shape(shape, i))
                           for i in range(rank)]
  # The sketched second-moment estimate is the elementwise minimum of the
  # broadcast per-dimension accumulators.
  current_accumulator = self._minimum(reshaped_accumulators)
  current_accumulator += grads * grads
  # Avoid dividing by zero where the accumulator has not been updated yet.
  accumulator_inv_sqrt = np.where(current_accumulator > 0.0,
                                  1.0 / np.sqrt(current_accumulator),
                                  np.zeros_like(current_accumulator))
  preconditioned_gradient = grads * accumulator_inv_sqrt
  m = (1.0 - momentum) * preconditioned_gradient + momentum * m
  params = params - (learning_rate * m).astype(params.dtype)
  # Fold the dense accumulator back into one vector per dimension by taking
  # the max over all other axes.
  for i in range(len(v)):
    axes = list(range(int(i))) + list(range(int(i) + 1, rank))
    dim_accumulator = np.amax(current_accumulator, axis=axes)
    v[i] = dim_accumulator
  return params, (m, v)
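# Both variants of `_update_sketched` in this section call two helpers that
# are not shown here: `_expanded_shape` and `_minimum`. The sketch below is a
# hypothetical reconstruction inferred from their call sites, not the actual
# definitions.

def _expanded_shape(self, shape, axis):
  """Return `shape` with every dimension except `axis` replaced by 1.

  E.g. [M, N, K] with axis=1 becomes [1, N, 1], so that the per-dimension
  accumulator v[axis] broadcasts against the full-rank gradient.
  """
  rank = len(shape)
  return [1] * axis + [shape[axis]] + [1] * (rank - axis - 1)

def _minimum(self, tensors):
  """Elementwise minimum over a list of mutually broadcastable tensors."""
  minimum = tensors[0]
  for tensor in tensors[1:]:
    minimum = np.minimum(minimum, tensor)
  return minimum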
def _update_sketched(self, step, g, x, m, v):
  """Update for higher-rank parameters."""
  shape = x.shape
  rank = len(shape)
  reshaped_accumulators = [
      np.reshape(v[i], self._expanded_shape(shape, i)) for i in range(rank)
  ]
  current_accumulator = self._minimum(reshaped_accumulators)
  current_accumulator += g * g
  accumulator_inv_sqrt = np.where(current_accumulator > 0.0,
                                  1.0 / np.sqrt(current_accumulator),
                                  np.zeros_like(current_accumulator))
  preconditioned_gradient = g * accumulator_inv_sqrt
  # This variant reads momentum from the instance and derives the learning
  # rate from the step count via self.step_size(step).
  m = (1.0 - self._momentum) * preconditioned_gradient + self._momentum * m
  x = x - self.step_size(step) * m
  # Collapse the dense accumulator back to one vector per dimension.
  for i in range(len(v)):
    axes = list(range(int(i))) + list(range(int(i) + 1, rank))
    dim_accumulator = np.amax(current_accumulator, axis=axes)
    v[i] = dim_accumulator
  return x, (m, v)
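# A minimal, self-contained usage sketch of the same sketched update as a
# free function, showing the state layout SM3 maintains: for a weight of
# shape (d0, d1) it stores one accumulator vector per axis, of shapes (d0,)
# and (d1,), instead of a dense (d0, d1) second-moment matrix. Everything
# below (names, shapes, hyperparameters) is illustrative, not the original
# optimizer's API.
import numpy as np

def sketched_update(grads, params, m, v, learning_rate=0.1, momentum=0.9):
  rank = len(params.shape)

  def expand(i):
    # Keep only axis i; every other axis becomes size 1 for broadcasting.
    return [1] * i + [params.shape[i]] + [1] * (rank - i - 1)

  # Elementwise minimum of the broadcast per-axis accumulators.
  current = np.reshape(v[0], expand(0))
  for i in range(1, rank):
    current = np.minimum(current, np.reshape(v[i], expand(i)))
  current = current + grads * grads
  # Same zero guard as the methods above.
  inv_sqrt = np.where(current > 0.0, 1.0 / np.sqrt(current), 0.0)
  m = (1.0 - momentum) * grads * inv_sqrt + momentum * m
  params = params - (learning_rate * m).astype(params.dtype)
  # Per-axis max recovers the vector accumulators.
  v = [np.amax(current, axis=tuple(a for a in range(rank) if a != i))
       for i in range(rank)]
  return params, (m, v)

params = np.ones((4, 6), dtype=np.float32)
grads = 0.01 * np.ones_like(params)
m = np.zeros_like(params)
v = [np.zeros(d, dtype=np.float32) for d in params.shape]
params, (m, v) = sketched_update(grads, params, m, v)
print([acc.shape for acc in v])  # [(4,), (6,)] -- one vector per axis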