Example #1
  def forward(self, x):
    """Executes this layer as part of a forward pass through the model.

    Args:
      x: Tensor of same shape and dtype as the input signature used to
        initialize this layer.

    Returns:
      Tensor of same shape and dtype as the input.
    """
    m1, w1, w2, b2 = self.weights
    x_shape = x.shape
    x = jnp.reshape(x, [-1, x_shape[-1]])  # Easier to operate on flattened x.

    # Q: should we add a bias and/or put a relu after the m1 dot?
    mask_logits = jnp.dot(x, m1)
    # Softmax.
    mask_logsumexp = fastmath.logsumexp(mask_logits, axis=-1, keepdims=True)
    log_mask = mask_logits - mask_logsumexp
    mask = jnp.exp(log_mask)
    # Gumbel-softmax with straight-through discretization.
    # TODO(lukaszkaiser, chowdhery): Extract this block and share
    rng1, rng2 = fastmath.random.split(self.rng, 2)
    u = fastmath.random.uniform(rng1, mask.shape, jnp.float32, 1e-6, 1.0 - 1e-6)
    g = -jnp.log(-jnp.log(u))
    selected_experts = jnp.argmax(log_mask + g * self._temperature, axis=-1)
    if self._mode == 'train':
      # Tricks from Section 2.1 in https://arxiv.org/abs/1801.09797
      quant_mask = tl.one_hot(selected_experts, self._num_experts)
      quant_mask = fastmath.stop_gradient(quant_mask)
      quant_mask += mask - fastmath.stop_gradient(mask)  # straight-through
      # We will sometimes (50% of the batches) use the soft-mask instead of
      # the quantized mask to improve training stability (see the paper above).
      # Q: is selecting 50% of batches the best? Other %? Mixed in-batch?
      select = fastmath.random.uniform(rng2, (), jnp.float32, -1.0, 1.0)
      quant_mask = jnp.where(select > 0.0, quant_mask, mask)
    else:
      quant_mask = tl.one_hot(selected_experts, self._num_experts)
    quant_mask = jnp.reshape(quant_mask, [-1, self._num_experts, 1])
    quant_mask_shape = quant_mask.shape
    batch_size = quant_mask.shape[0]

    if self._mode == 'predict' and batch_size == 1:
      # This implementation mimics inference for batch_size 1.
      start_idx = selected_experts[0] * self._n_elements_in_block
      # w1 is [d_model, d_ff], w is [d_model, n_elements_in_block]
      w = fastmath.dynamic_slice(w1, [0, start_idx],
                                 [w1.shape[0], self._n_elements_in_block])
      mid = jnp.dot(x, w)
      relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid)
      # w2 is [d_ff, d_model], v is [n_elements_in_block, d_model]
      v = fastmath.dynamic_slice(w2, [start_idx, 0],
                                 [self._n_elements_in_block, w2.shape[-1]])
      v = jnp.reshape(v, [self._n_elements_in_block, -1])
      res = jnp.dot(relu, v) + b2
    else:
      expanded_mask = jnp.broadcast_to(
          quant_mask,
          (quant_mask_shape[0], quant_mask_shape[1], self._n_elements_in_block))
      expanded_mask = jnp.reshape(expanded_mask, (-1, self._d_ff))
      mid = jnp.dot(x, w1) * expanded_mask  # [joint_batch, d_ff]
      relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid)
      res = jnp.dot(relu, w2) + b2

    return jnp.reshape(res, x_shape)  # un-flatten if needed
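
For reference, here is a minimal standalone sketch of the Gumbel-softmax straight-through trick used in the block above, written against plain jax / jax.numpy instead of trax.fastmath; the name gumbel_softmax_st and the use of the jax.nn helpers are illustrative assumptions, not part of the layer.

import jax
import jax.numpy as jnp

def gumbel_softmax_st(logits, rng, temperature=1.0):
  """Straight-through Gumbel-softmax: hard one-hot forward, soft gradients."""
  log_probs = jax.nn.log_softmax(logits, axis=-1)       # log_mask above
  soft_mask = jnp.exp(log_probs)                        # mask above
  # Sample Gumbel noise and pick one expert per example.
  u = jax.random.uniform(rng, logits.shape, jnp.float32, 1e-6, 1.0 - 1e-6)
  g = -jnp.log(-jnp.log(u))
  selected = jnp.argmax(log_probs + g * temperature, axis=-1)
  hard_mask = jax.nn.one_hot(selected, logits.shape[-1])
  # Straight-through: the forward value is hard_mask, while the gradient is
  # the gradient of soft_mask.
  return jax.lax.stop_gradient(hard_mask - soft_mask) + soft_mask

k_logits, k_gumbel = jax.random.split(jax.random.PRNGKey(0))
mask = gumbel_softmax_st(jax.random.normal(k_logits, (4, 8)), k_gumbel)
print(mask.shape)  # (4, 8): one (soft-gradient) one-hot row per example
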
Example #2
def clip_grads(grad_tree, max_norm):
  """Clip gradients stored as a pytree of arrays to maximum norm `max_norm`."""
  norm = l2_norm(grad_tree)
  normalize = lambda g: jnp.where(norm < max_norm, g, g * (max_norm / norm))
  return layers.nested_map(grad_tree, normalize)
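
Here `l2_norm` and `layers.nested_map` come from the surrounding library; for context, a self-contained equivalent built only on jax.tree_util (the names global_l2_norm and clip_grads_standalone are illustrative) might look like this:

import jax
import jax.numpy as jnp

def global_l2_norm(tree):
  # L2 norm taken jointly over every leaf of the pytree.
  leaves = jax.tree_util.tree_leaves(tree)
  return jnp.sqrt(sum(jnp.sum(jnp.square(g)) for g in leaves))

def clip_grads_standalone(grad_tree, max_norm):
  """Scale all gradients so their joint L2 norm is at most max_norm."""
  norm = global_l2_norm(grad_tree)
  scale = jnp.where(norm < max_norm, 1.0, max_norm / norm)
  return jax.tree_util.tree_map(lambda g: g * scale, grad_tree)

grads = {'w': jnp.ones((3, 3)), 'b': jnp.ones((3,))}
clipped = clip_grads_standalone(grads, max_norm=1.0)
print(global_l2_norm(clipped))  # ~1.0 once clipping kicks in
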
Example #3
  def forward(self, x):
    """Executes this layer as part of a forward pass through the model.

    Args:
      x: Tensor of same shape and dtype as the input signature used to
          initialize this layer.

    Returns:
      Tensor of same shape and dtype as the input.
    """
    m1, m2, mb, w1, w2, b2 = self.weights
    if self._mode != 'predict':
      w1 = jnp.reshape(w1.T, (-1, self._d_ff))
      w2 = jnp.reshape(w2, (self._d_ff, -1))
    x_shape = x.shape
    x = jnp.reshape(x, [-1, x_shape[-1]])  # Easier to operate on flattened x.

    # Q: should we add bias and/or put relu after the low-rank m1 dot?
    mask_logits = jnp.dot(jnp.dot(x, m1), m2) + mb
    mask_logits = jnp.reshape(mask_logits, [-1, self._d1, self._d2])
    # Softmax.
    mask_logsumexp = fastmath.logsumexp(mask_logits, axis=-1, keepdims=True)
    log_mask = mask_logits - mask_logsumexp
    mask = jnp.exp(log_mask)
    # Gumbel-softmax with straight-through discretization.
    rng1, rng2 = fastmath.random.split(self.rng, 2)
    u = fastmath.random.uniform(rng1, mask.shape, jnp.float32, 1e-6, 1.0 - 1e-6)
    g = -jnp.log(-jnp.log(u))
    quant_mask = jnp.argmax(log_mask + g * self._temperature, axis=-1)
    if self._mode == 'train':
      # Tricks from Section 2.1 in https://arxiv.org/abs/1801.09797
      quant_mask = tl.one_hot(quant_mask, self._n_elements_in_block)
      quant_mask = fastmath.stop_gradient(quant_mask)
      quant_mask += mask - fastmath.stop_gradient(mask)  # straight-through
      # With probability quant_prob we keep the quantized mask; otherwise we
      # fall back to the soft mask for training stability (see paper above).
      select = fastmath.random.uniform(rng2, (), jnp.float32, 0.0, 1.0)
      quant_mask = jnp.where(select < self._quant_prob, quant_mask, mask)
      quant_mask = jnp.reshape(quant_mask, [-1, self._d_ff])

    if self._mode == 'train':
      # In training, run full matmul to get benefits from the above tricks.
      mid = jnp.dot(x, w1) * quant_mask  # [joint_batch, d_ff]
      relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid)
      res = jnp.dot(relu, w2) + b2
    elif self._mode == 'predict':
      # w1 = jnp.reshape(w1.T, (self._d1, self._d2, -1))
      # w2 = jnp.reshape(w2, (self._d1, self._d2, -1))
      # This implementation mimics inference. It is not efficient for a large
      # joint_batch, but at inference time joint_batch will usually be 1.
      # Shapes:
      # quant_mask is [joint_batch, self._d1]
      # w1 in 'predict' mode is [self._d1, self._d2, d_model]
      # We index w1 with advanced numpy indexing: the first index array ranges
      # over self._d1, repeated batch_size times; the second is quant_mask.
      batch_size = quant_mask.shape[0]
      idx1 = jnp.array([jnp.arange(self._d1)] * batch_size)
      # flatten indices and select from w1
      idx1 = jnp.reshape(idx1, [-1])
      idx2 = jnp.reshape(quant_mask, [-1])
      w = w1[idx1, idx2, :]  # now we have per-element weights with batch dim
      w = jnp.reshape(w, [batch_size, self._d1, -1])
      mid = jnp.einsum('ai,aji->aj', x, w)
      relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid)
      # w2 is [self._d1, self._d2, d_model]
      v = w2[idx1, idx2, :]
      v = jnp.reshape(v, [batch_size, self._d1, -1])
      res = jnp.einsum('ai,aij->aj', relu, v) + b2
    else:
      quant_mask = tl.one_hot(quant_mask, self._n_elements_in_block)
      quant_mask = jnp.reshape(quant_mask, [-1, self._d_ff])
      mid = jnp.dot(x, w1) * quant_mask  # [joint_batch, d_ff]
      relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid)
      res = jnp.dot(relu, w2) + b2

    return jnp.reshape(res, x_shape)  # un-flatten if needed
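
To make the 'predict' branch above easier to follow, here is a shape-only sketch of its gather-and-einsum pattern with small made-up dimensions; the [d1, d2, d_model] weight layout and all sizes are assumptions for illustration.

import jax
import jax.numpy as jnp

batch, d_model, d1, d2 = 2, 8, 4, 3   # toy sizes: d1 blocks, d2 choices each

k_x, k_w = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(k_x, (batch, d_model))
w1_blocks = jax.random.normal(k_w, (d1, d2, d_model))

# One chosen sub-block per (example, block), as produced by the argmax above.
quant_mask = jnp.array([[0, 2, 1, 0],
                        [1, 1, 0, 2]])            # [batch, d1]

idx1 = jnp.tile(jnp.arange(d1), batch)            # block index, flattened
idx2 = jnp.reshape(quant_mask, [-1])              # chosen sub-block, flattened
w = w1_blocks[idx1, idx2, :]                      # [batch * d1, d_model]
w = jnp.reshape(w, [batch, d1, d_model])

# Contract x with the per-example weights: one activation per selected block.
mid = jnp.einsum('ai,aji->aj', x, w)
print(mid.shape)  # (2, 4)
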
Example #4
 def relu(x):
   return jnp.where(x <= 0, jnp.zeros_like(x), x)
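
This is just ReLU spelled with jnp.where; a quick sanity check (not from the library) that it matches jnp.maximum:

import jax.numpy as jnp

def relu(x):
  return jnp.where(x <= 0, jnp.zeros_like(x), x)

x = jnp.array([-2.0, -0.5, 0.0, 0.5, 2.0])
print(relu(x))              # [0.  0.  0.  0.5 2. ]
print(jnp.maximum(x, 0.0))  # same values
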
Example #5
 def _convert_to_nans(x, y):
     # If every value in y is non-zero, return x unchanged; otherwise replace
     # x with x / 0. (NaN/Inf) so it can be filtered out downstream.
     return jnp.where(jnp.all(y, keepdims=False), x, x / 0.), y
Example #6
 def non_nan(x):  # pylint: disable=invalid-name
     return jnp.where(jnp.isnan(x), 0., x)
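
Examples #5 and #6 are easiest to read as a pair: the first poisons a value with x / 0. whenever its companion y contains a zero, and the second later zeroes out the resulting NaNs. Below is a standalone sketch of that pattern under an assumed sharded-loss setting (not taken from the surrounding code); note it only produces a NaN, rather than an Inf, when the poisoned x is itself 0.

import jax.numpy as jnp

def _convert_to_nans(x, y):
  # If every value in y is non-zero, keep x; otherwise turn x into x / 0.
  # (NaN when x == 0) so downstream code can recognize and drop it.
  return jnp.where(jnp.all(y, keepdims=False), x, x / 0.), y

def non_nan(x):
  # Replace NaNs with zeros so poisoned entries drop out of later sums.
  return jnp.where(jnp.isnan(x), 0., x)

# One (loss_sum, mask) pair per shard; the second shard has nothing to train
# on, so its masked loss sum is already 0 and 0 / 0. becomes NaN.
shard_losses = jnp.array([1.4, 0.0])
shard_masks = [jnp.array([1., 1., 1.]), jnp.array([0., 0., 0.])]

poisoned = jnp.stack([_convert_to_nans(l, m)[0]
                      for l, m in zip(shard_losses, shard_masks)])
total = jnp.sum(non_nan(poisoned))
print(poisoned, total)  # [1.4 nan] 1.4
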
Example #7
    def update(self, step, grads, weights, slots, opt_params):
        updates = []
        learning_rate = opt_params['learning_rate']
        beta1 = opt_params['beta1']
        decay_rate = opt_params['decay_rate']
        clipping_threshold = opt_params['clipping_threshold']
        weight_decay_rate = opt_params['weight_decay_rate']
        weight_decay_n_steps = opt_params['weight_decay_n_steps']
        weight_decay_rate = jnp.where(
            weight_decay_n_steps < 1,  # if weight_decay_n_steps == 0, ignore it
            weight_decay_rate,
            (weight_decay_rate *
             jnp.maximum(weight_decay_n_steps - step, 0.0) /
             jnp.maximum(weight_decay_n_steps, 0.0)))
        epsilon1 = opt_params['epsilon1']
        epsilon2 = opt_params['epsilon2']
        decay_rate = self._decay_rate_pow(step, exponent=decay_rate)
        update_scale = learning_rate
        if self._multiply_by_parameter_scale:
            update_scale *= jnp.maximum(jnp.sqrt(jnp.mean(weights * weights)),
                                        epsilon2)
        mixing_rate = 1.0 - decay_rate

        grads_sqr = grads * grads
        if self._factored and len(weights.shape) >= 2:
            v_row = slots[0]  # In this case, the slots are (v_row, v_col, ...).
            v_col = slots[1]
            new_v_row = (decay_rate * v_row +
                         mixing_rate * jnp.mean(grads_sqr, axis=-1))
            new_v_col = (decay_rate * v_col +
                         mixing_rate * jnp.mean(grads_sqr, axis=-2))
            updates.extend([new_v_row, new_v_col])
            row_mean = jnp.mean(new_v_row, axis=-1, keepdims=True)
            row_factor = (row_mean / (new_v_row + epsilon1))**0.5
            col_factor = (new_v_col + epsilon1)**-0.5
            y = (grads * jnp.expand_dims(row_factor, axis=-1) *
                 jnp.expand_dims(col_factor, axis=-2))
        else:
            v = slots[0]  # In this case, the slots are (v, ...)
            new_v = decay_rate * v + mixing_rate * grads_sqr
            updates.append(new_v)
            y = grads * (new_v + epsilon1)**-0.5

        if self._do_clipping:
            clipping_denom = jnp.maximum(
                1.0, jnp.sqrt(jnp.mean(y * y)) / clipping_threshold)
            y /= clipping_denom

        subtrahend = update_scale * y
        if self._do_momentum:
            m = slots[-1]  # Momentum is always the last slot (if used).
            m = m.astype(subtrahend.dtype)  # Accumulate in subtrahend dtype.
            new_m = beta1 * m + (1.0 - beta1) * subtrahend
            subtrahend = new_m
            # Back to bfloat if needed.
            updates.append(new_m.astype(slots[-1].dtype))

        new_weights = (1 - weight_decay_rate) * weights - subtrahend
        # TODO(lukaszkaiser): why is the astype needed here? Check and correct.
        return new_weights.astype(weights.dtype), updates
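
As context for the factored branch above: instead of a full second-moment matrix, only a row mean and a column mean of the squared gradients are stored, and their outer product (normalized by the mean of the row statistic) reconstructs the per-parameter scaling. A minimal standalone sketch with made-up shapes and decay_rate fixed to 0 (so no running average); this is not the optimizer's actual API.

import jax
import jax.numpy as jnp

epsilon1 = 1e-30
grads = jax.random.normal(jax.random.PRNGKey(0), (4, 6))
grads_sqr = grads * grads

# Factored statistics: O(n + m) memory instead of the O(n * m) full moment.
v_row = jnp.mean(grads_sqr, axis=-1)   # shape [4]
v_col = jnp.mean(grads_sqr, axis=-2)   # shape [6]

# Preconditioned update, as in the factored branch above.
row_mean = jnp.mean(v_row, axis=-1, keepdims=True)
row_factor = (row_mean / (v_row + epsilon1)) ** 0.5
col_factor = (v_col + epsilon1) ** -0.5
y = (grads * jnp.expand_dims(row_factor, axis=-1)
     * jnp.expand_dims(col_factor, axis=-2))

# The same update written via the rank-1 estimate of the second moment.
v_hat = jnp.outer(v_row, v_col) / row_mean
print(jnp.allclose(y, grads / jnp.sqrt(v_hat)))  # True (up to epsilon1)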