Example #1
 def _update_sketched(self, grads, params, m, v, opt_params):
   """Update for higher-rank parameters."""
   # `np` is assumed to be a NumPy-compatible module (e.g. jax.numpy).
   (learning_rate, momentum) = opt_params
   shape = params.shape
   rank = len(shape)
   # Rebuild an upper bound on the full second-moment accumulator from the
   # per-axis sketches v[i], each broadcast back to the parameter shape.
   reshaped_accumulators = [np.reshape(v[i], self._expanded_shape(shape, i))
                            for i in range(rank)]
   current_accumulator = self._minimum(reshaped_accumulators)
   current_accumulator += grads * grads
   # Avoid dividing by zero for entries with no accumulated signal yet.
   accumulator_inv_sqrt = np.where(current_accumulator > 0.0,
                                   1.0 / np.sqrt(current_accumulator),
                                   np.zeros_like(current_accumulator))
   preconditioned_gradient = grads * accumulator_inv_sqrt
   m = (1.0 - momentum) * preconditioned_gradient + momentum * m
   params = params - (learning_rate * m).astype(params.dtype)
   # Re-sketch: v[i] keeps the max of the accumulator over all other axes.
   for i in range(len(v)):
     axes = list(range(int(i))) + list(range(int(i) + 1, rank))
     dim_accumulator = np.amax(current_accumulator, axis=axes)
     v[i] = dim_accumulator
   return params, (m, v)
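
Both examples call two helpers that are not shown. The following is a minimal sketch of what they plausibly do, inferred from how they are used above; they are written here as free functions (in the source they are methods, `self._expanded_shape` and `self._minimum`), and the bodies are assumptions rather than the original code.

 import numpy as np

 def _expanded_shape(shape, axis):
   """Return `shape` with 1 in every dimension except `axis`.

   E.g. shape (3, 4, 5) with axis=1 gives [1, 4, 1], so the 1-d sketch
   v[axis] can broadcast against the full parameter shape.
   """
   return [1] * axis + [shape[axis]] + [1] * (len(shape) - axis - 1)

 def _minimum(tensors):
   """Element-wise minimum over a list of mutually broadcastable arrays."""
   out = tensors[0]
   for t in tensors[1:]:
     out = np.minimum(out, t)
   return out
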
Example #2
 def _update_sketched(self, step, g, x, m, v):
     """Update for higher-rank parameters."""
     shape = x.shape
     rank = len(shape)
     reshaped_accumulators = [
         np.reshape(v[i], self._expanded_shape(shape, i))
         for i in range(rank)
     ]
     current_accumulator = self._minimum(reshaped_accumulators)
     current_accumulator += g * g
     accumulator_inv_sqrt = np.where(current_accumulator > 0.0,
                                     1.0 / np.sqrt(current_accumulator),
                                     np.zeros_like(current_accumulator))
     preconditioned_gradient = g * accumulator_inv_sqrt
      m = ((1.0 - self._momentum) * preconditioned_gradient +
           self._momentum * m)
     x = x - self.step_size(step) * m
     for i in range(len(v)):
         axes = list(range(int(i))) + list(range(int(i) + 1, rank))
         dim_accumulator = np.amax(current_accumulator, axis=axes)
         v[i] = dim_accumulator
     return x, (m, v)
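
The two examples implement the same sketched update: Example #2 is an older-style signature that reads the momentum and step size from instance attributes (`self._momentum`, `self.step_size(step)`) rather than from an `opt_params` tuple, and it skips the final cast to the parameter dtype. The core trick in both is that each 1-d sketch v[i] stores the max of the accumulator over all other axes, and the broadcast minimum of the sketches never undershoots the true per-entry sum of squared gradients. A small self-contained illustration in plain NumPy (all names here are illustrative, not from the source):

 import numpy as np

 rng = np.random.default_rng(0)
 shape = (3, 4)
 exact = np.zeros(shape)            # full accumulator, kept only for comparison
 v = [np.zeros(n) for n in shape]   # one 1-d sketch per axis

 for _ in range(10):
   g = rng.normal(size=shape)
   # Rebuild an upper bound from the sketches, then add the new gradient.
   current = np.minimum(v[0].reshape(3, 1), v[1].reshape(1, 4)) + g * g
   exact += g * g
   # Re-sketch: max over every axis except i.
   v = [current.max(axis=1), current.max(axis=0)]

 # The reconstruction upper-bounds the exact accumulator at every entry.
 bound = np.minimum(v[0].reshape(3, 1), v[1].reshape(1, 4))
 assert np.all(bound >= exact - 1e-9)

The memory saving is the point: the sketches store 3 + 4 = 7 values instead of the 12 a full accumulator would need, and the gap grows with tensor size.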