Example #1
    def _update_sketched(self, g, w, m, v1, v2, opt_params):
        """Update for higher-rank parameters."""
        learning_rate = opt_params['learning_rate']
        momentum = opt_params['momentum']
        beta2 = opt_params['second_moment_averaging']
        weight_decay = opt_params['weight_decay']

        shape = w.shape
        rank = len(shape)
        # Recover the sketched second-moment estimate: broadcast each
        # per-dimension accumulator to the full rank and take the
        # elementwise minimum.
        reshaped_accumulators = [
            jnp.reshape(v1[i], self._expanded_shape(shape, i))
            for i in range(rank)
        ]
        acc = self._minimum(reshaped_accumulators)

        # With beta2 == 1 the g^2 term is accumulated with weight 1
        # (AdaGrad-style sum); otherwise it is an EMA with weight 1 - beta2.
        is_beta2_1 = jnp.asarray(beta2 == 1, dtype=g.dtype)
        one_minus_beta2_except1 = (
            is_beta2_1 + (1.0 - beta2) * (1.0 - is_beta2_1))
        acc = beta2 * acc + one_minus_beta2_except1 * g * g

        preconditioner = jnp.where(acc > 0.0, 1.0 / (jnp.sqrt(acc) + 1e-16),
                                   jnp.zeros_like(acc))
        pg = g * preconditioner
        if self._graft:
            # Grafting: keep the sketched update's direction but rescale it
            # to the norm of an AdaGrad-preconditioned gradient.
            v2_acc = self._minimum([
                jnp.reshape(v2[i], self._expanded_shape(shape, i))
                for i in range(rank)
            ])
            v2_acc = v2_acc + g * g
            preconditioner_graft = jnp.where(v2_acc > 0.0,
                                             1.0 / (jnp.sqrt(v2_acc) + 1e-16),
                                             jnp.zeros_like(v2_acc))
            pg_graft = preconditioner_graft * g
            pg_norm = jnp.linalg.norm(pg)
            pg_graft_norm = jnp.linalg.norm(pg_graft)
            pg = pg * (pg_graft_norm / (pg_norm + 1e-16))

        pg = pg + w * weight_decay

        if self._has_momentum:
            m, update = self._momentum_update(pg, m, momentum)
        else:
            update = pg

        w = w - (learning_rate * update).astype(w.dtype)
        # Write the sketch back: v1[i] keeps the max of the accumulator over
        # every axis except i.
        for i in range(len(v1)):
            axes = list(range(i)) + list(range(i + 1, rank))
            dim_accumulator = jnp.amax(acc, axis=axes)
            v1[i] = dim_accumulator

        if self._graft:
            # Same write-back for the grafting accumulators.
            for i in range(len(v2)):
                axes = list(range(i)) + list(range(i + 1, rank))
                dim_accumulator = jnp.amax(v2_acc, axis=axes)
                v2[i] = dim_accumulator
        return w, (m, v1, v2)
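
Examples #1 and #3 both call two small helpers that the snippets do not show. A minimal sketch of what they would look like, inferred from how the snippets use them (the bodies here are an assumption, not code from the source):

    def _expanded_shape(self, shape, axis):
        # Keep dimension `axis` of `shape` and set every other dimension
        # to 1, e.g. (M, N, K) with axis=1 -> [1, N, 1].
        return [1] * axis + [shape[axis]] + [1] * (len(shape) - axis - 1)

    def _minimum(self, tensor_list):
        # Elementwise minimum over a list of mutually broadcastable tensors;
        # this reconstructs the sketched accumulator from per-dimension maxima.
        minimum = tensor_list[0]
        for t in tensor_list[1:]:
            minimum = jnp.minimum(minimum, t)
        return minimum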
Example #2
def representation_mask(mask):
    # mask shape: (batch_size, 4, ...); reduce away any extra trailing axes.
    mask = jnp.amax(mask, axis=tuple(range(2, mask.ndim)))
    # mask shape: (batch_size, 4)
    mask = jnp.repeat(mask[..., jnp.newaxis],
                      repeats=serializer.representation_length,
                      axis=2)
    # mask shape: (batch_size, 4, representation_length)
    return mask
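
A quick shape check of the function above (the batch size and representation_length=3 are made-up values for the example; serializer is assumed to be in scope, as in the snippet):

from types import SimpleNamespace
import jax.numpy as jnp

serializer = SimpleNamespace(representation_length=3)  # assumed stand-in

mask = jnp.array([[1, 0, 1, 1],
                  [0, 0, 1, 0]])  # (batch_size=2, 4)
out = representation_mask(mask)
assert out.shape == (2, 4, 3)  # (batch_size, 4, representation_length)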
Example #3
def _update_sketched(self, grads, weights, m, v, opt_params):
  """Update for higher-rank parameters."""
  learning_rate = opt_params['learning_rate']
  momentum = opt_params['momentum']
  shape = weights.shape
  rank = len(shape)
  # Recover the sketched accumulator: broadcast each per-dimension slot to
  # the full rank and take the elementwise minimum.
  reshaped_accumulators = [jnp.reshape(v[i], self._expanded_shape(shape, i))
                           for i in range(rank)]
  current_accumulator = self._minimum(reshaped_accumulators)
  current_accumulator += grads * grads
  accumulator_inv_sqrt = jnp.where(current_accumulator > 0.0,
                                   1.0 / jnp.sqrt(current_accumulator),
                                   jnp.zeros_like(current_accumulator))
  preconditioned_gradient = grads * accumulator_inv_sqrt
  # Momentum as an exponential moving average of the preconditioned gradient.
  m = (1.0 - momentum) * preconditioned_gradient + momentum * m
  weights = weights - (learning_rate * m).astype(weights.dtype)
  # Write the sketch back: v[i] keeps the max over every axis except i.
  for i in range(len(v)):
    axes = list(range(i)) + list(range(i + 1, rank))
    dim_accumulator = jnp.amax(current_accumulator, axis=axes)
    v[i] = dim_accumulator
  return weights, (m, v)
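
The per-dimension slots v read and written above would be created as one 1-D accumulator per tensor dimension, since the update reshapes v[i] to _expanded_shape(shape, i). A minimal init sketch under that assumption (init_sketched_slots is a hypothetical name, not from the source):

import jax.numpy as jnp

def init_sketched_slots(weights):
  # One 1-D accumulator per dimension of `weights`, plus a momentum slot.
  m = jnp.zeros_like(weights)
  v = [jnp.zeros(d) for d in weights.shape]
  return m, v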
Example #4
def representation_mask(mask):
  # Reduce away any axes beyond the first two.
  mask = jnp.amax(mask, axis=tuple(range(2, mask.ndim)))
  # Append a representation_length axis by broadcasting rather than copying.
  return jnp.broadcast_to(
      mask[:, :, None], mask.shape + (serializer.representation_length,)
  )
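
Examples #2 and #4 are two spellings of the same broadcast and agree elementwise; a quick check (the shapes and the representation length 3 are made up for the example):

import jax.numpy as jnp

mask = jnp.ones((2, 4))
a = jnp.repeat(mask[..., jnp.newaxis], repeats=3, axis=2)   # Example #2 style
b = jnp.broadcast_to(mask[:, :, None], mask.shape + (3,))   # Example #4 style
assert (a == b).all()

The broadcast_to form never spells out the repeated values, while jnp.repeat materializes them eagerly when run outside of jit; under jit both typically compile to the same broadcast.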