Example #1
    def update(self, step, grads, params, slots, opt_params):
        updates = []
        learning_rate = opt_params["learning_rate"]
        beta1 = opt_params["beta1"]
        decay_rate = opt_params["decay_rate"]
        clipping_threshold = opt_params["clipping_threshold"]
        weight_decay_rate = opt_params["weight_decay_rate"]
        epsilon1 = opt_params["epsilon1"]
        epsilon2 = opt_params["epsilon2"]
        decay_rate = self._decay_rate_pow(step, exponent=decay_rate)
        update_scale = learning_rate
        if self._multiply_by_parameter_scale:
            update_scale *= np.maximum(np.sqrt(np.mean(params * params)),
                                       epsilon2)
        mixing_rate = 1.0 - decay_rate

        grads_sqr = grads * grads + epsilon1
        if self._factored and len(params.shape) >= 2:
            v_row = slots.pop(0)
            v_col = slots.pop(0)
            new_v_row = decay_rate * v_row + mixing_rate * np.mean(grads_sqr,
                                                                   axis=-1)
            new_v_col = decay_rate * v_col + mixing_rate * np.mean(grads_sqr,
                                                                   axis=-2)
            updates.extend([new_v_row, new_v_col])
            row_col_mean = np.mean(new_v_row, axis=-1, keepdims=True)
            row_factor = (new_v_row / row_col_mean)**-0.5
            col_factor = (new_v_col)**-0.5
            y = (grads * np.expand_dims(row_factor, axis=-1) *
                 np.expand_dims(col_factor, axis=-2))
        else:
            v = slots.pop(0)
            new_v = decay_rate * v + mixing_rate * grads_sqr
            updates.append(new_v)
            y = grads * (new_v)**-0.5

        if self._do_clipping:
            clipping_denom = (np.maximum(
                1.0,
                np.sqrt(np.mean(y * y)) / clipping_threshold))
            y /= clipping_denom

        subtrahend = update_scale * y
        if self._do_momentum:
            m = slots.pop(0)
            new_m = beta1 * m + (1.0 - beta1) * subtrahend
            subtrahend = new_m
            updates.append(new_m)

        new_params = (1 - weight_decay_rate) * params - subtrahend
        # TODO(lukaszkaiser): why is the astype needed here? Check and correct.
        return new_params.astype(params.dtype), updates
Example #2
 def update(self, step, grads, params, avg_sq_grad, opt_params):
     del step
     (learning_rate, gamma, eps) = opt_params
     avg_sq_grad = avg_sq_grad * gamma + grads**2 * (1. - gamma)
     params = params - (learning_rate * grads /
                        (np.sqrt(avg_sq_grad) + eps)).astype(params.dtype)
     return params, avg_sq_grad
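
A minimal standalone sketch of the same RMSProp-style accumulator update in plain NumPy, outside the optimizer class (the function name, hyperparameter defaults, and toy values below are illustrative, not from the library):

import numpy as onp

def rmsprop_step(params, grads, avg_sq_grad,
                 learning_rate=1e-3, gamma=0.9, eps=1e-8):
    # Exponential moving average of squared gradients.
    avg_sq_grad = avg_sq_grad * gamma + grads**2 * (1.0 - gamma)
    # Each coordinate's step is divided by the running RMS of its gradients.
    params = params - learning_rate * grads / (onp.sqrt(avg_sq_grad) + eps)
    return params, avg_sq_grad

params = onp.ones(3)
state = onp.zeros(3)
grads = onp.array([0.1, -0.2, 0.3])
params, state = rmsprop_step(params, grads, state)
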
Example #3
def DotProductAttention(query, key, value, mask, dropout, mode, rng):
    """Core dot product self-attention.

  Args:
    query: array of representations
    key: array of representations
    value: array of representations
    mask: attention-mask, gates attention
    dropout: float: dropout rate
    mode: 'eval' or 'train': whether to use dropout
    rng: JAX PRNGKey: subkey for disposable use

  Returns:
    Self attention for q, k, v arrays.
  """
    depth = np.shape(query)[-1]
    dots = np.matmul(query, np.swapaxes(key, -1, -2)) / np.sqrt(depth)
    if mask is not None:
        # TODO(kitaev): workaround for https://github.com/google/jax/issues/850
        # We must ensure that both mask and the -1e9 constant have a data dependency
        # on the input. Broadcasted copies of these use a lot of memory, so they
        # should be computed at runtime (rather than being global constants).
        if backend.get_name() == 'jax':
            mask = jax.lax.tie_in(dots, mask)
        dots = np.where(mask, dots, np.full_like(dots, -1e9))
    # Softmax.
    dots = np.exp(dots - backend.logsumexp(dots, axis=-1, keepdims=True))
    if dropout is not None and dropout >= 1.0:
        raise ValueError('Dropout rates must be lower than 1.')
    if dropout is not None and dropout > 0.0 and mode == 'train':
        keep = backend.random.bernoulli(rng, 1.0 - dropout, dots.shape)
        dots = np.where(keep, dots / (1.0 - dropout), np.zeros_like(dots))
    out = np.matmul(dots, value)
    return out
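
The np.exp(dots - logsumexp(dots)) line is a numerically stable softmax. A small standalone check in plain NumPy/SciPy (illustrative only) that it matches the max-shifted softmax used in the chunked attention examples further below:

import numpy as onp
from scipy.special import logsumexp

logits = onp.array([[1.0, 2.0, 3.0], [1000.0, 1001.0, 1002.0]])
# Softmax via the logsumexp identity (as in the attention code above).
via_lse = onp.exp(logits - logsumexp(logits, axis=-1, keepdims=True))
# Softmax via max-shifting (as in the chunked attention examples).
shifted = onp.exp(logits - logits.max(axis=-1, keepdims=True))
via_max = shifted / shifted.sum(axis=-1, keepdims=True)
assert onp.allclose(via_lse, via_max)  # both avoid overflow on large logits
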
Example #4
def dot_product_attention(query, key, value, mask, dropout, mode, rng):
    """Core dot product self-attention.

  Args:
    query: array of representations
    key: array of representations
    value: array of representations
    mask: attention-mask, gates attention
    dropout: float: keep probability for attention dropout (not a drop rate)
    mode: 'eval' or 'train': whether to use dropout
    rng: JAX PRNGKey: subkey for disposable use

  Returns:
    Self attention for q, k, v arrays.
  """
    depth = np.shape(query)[-1]
    dots = np.matmul(query, np.swapaxes(key, -1, -2)) / np.sqrt(depth)
    if mask is not None:
        dots = np.where(mask, dots, -1e9)
    dots = stax.softmax(dots, axis=-1)
    if dropout is not None and mode == 'train':
        keep = random.bernoulli(rng, dropout, dots.shape)
        dots = np.where(keep, dots / dropout, 0)
    out = np.matmul(dots, value)
    return out
Example #5
def BatchNorm(x,
              params,
              axis=(0, 1, 2),
              epsilon=1e-5,
              center=True,
              scale=True,
              **unused_kwargs):
    """Layer construction function for a batch normalization layer."""
    mean = np.mean(x, axis, keepdims=True)
    # Fast but less numerically-stable variance calculation than np.var.
    m1 = np.mean(x**2, axis, keepdims=True)
    var = m1 - mean**2
    z = (x - mean) / np.sqrt(var + epsilon)

    # Expand the parameters to have the right axes.
    beta, gamma = params
    # TODO(phawkins): np.expand_dims should accept an axis tuple.
    # (https://github.com/numpy/numpy/issues/12290)
    ed = tuple(None if i in axis else slice(None) for i in range(np.ndim(x)))
    beta = beta[ed]
    gamma = gamma[ed]

    # Return the z rescaled by the parameters if requested.
    if center and scale:
        return gamma * z + beta
    if center:
        return z + beta
    if scale:
        return gamma * z
    return z
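
The m1 - mean**2 line relies on the identity Var[x] = E[x^2] - (E[x])^2, which is faster but less numerically stable than np.var. A quick illustrative check in plain NumPy (shapes and tolerance are arbitrary):

import numpy as onp

x = onp.random.RandomState(0).randn(2, 3, 4).astype(onp.float32)
axis = (0, 1, 2)
mean = onp.mean(x, axis, keepdims=True)
fast_var = onp.mean(x**2, axis, keepdims=True) - mean**2  # E[x^2] - E[x]^2
assert onp.allclose(fast_var, onp.var(x, axis, keepdims=True), rtol=1e-4)
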
Example #6
 def apply_fun(params, inputs, **kwargs):
     del kwargs
     (scale, bias) = params
     mean = np.mean(inputs, axis=-1, keepdims=True)
     variance = np.mean((inputs - mean)**2, axis=-1, keepdims=True)
     norm_inputs = (inputs - mean) / np.sqrt(variance + epsilon)
     return norm_inputs * scale + bias
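
The same layer-normalization computation as a self-contained NumPy function (a sketch; the function name, epsilon default, and example input are illustrative):

import numpy as onp

def layer_norm(inputs, scale, bias, epsilon=1e-6):
    # Normalize every vector along the last (feature) axis, then
    # apply a learned per-feature scale and bias.
    mean = onp.mean(inputs, axis=-1, keepdims=True)
    variance = onp.mean((inputs - mean)**2, axis=-1, keepdims=True)
    norm_inputs = (inputs - mean) / onp.sqrt(variance + epsilon)
    return norm_inputs * scale + bias

x = onp.arange(12, dtype=onp.float32).reshape(3, 4)
out = layer_norm(x, scale=onp.ones(4), bias=onp.zeros(4))
# Each row of `out` now has roughly zero mean and unit variance.
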
Example #7
        def binned_attn(sq, sk, sv):
            """Performs attention on sorted queries/keys/values."""
            # Split off a "bin" axis so that attention only occurs within chunks.
            bq_t = chunk_rank3(sq_t)
            bkv_t = chunk_rank3(skv_t)
            bq = chunk_rank4(sq)
            bk = chunk_rank4(sk)
            bv = chunk_rank4(sv)

            dots = np.matmul(bq, np.swapaxes(bk, -1, -2)) / np.sqrt(
                bq.shape[-1])

            # Causal masking
            mask = jax.lax.convert_element_type(
                jax.lax.lt(bq_t[:, :, :, :, None], bkv_t[:, :, :, None, :]),
                np.float32)
            dots = dots - 1e9 * mask

            # Softmax.
            dots = np.exp(dots - dots.max(axis=-1, keepdims=True))
            dots = dots / dots.sum(axis=-1, keepdims=True)
            bo = np.matmul(dots, bv)

            so = unchunk_rank4(bo)
            return so
Example #8
def DotProductAttention(query, key, value, mask, dropout, mode, rng):
  """Core dot product self-attention.

  Args:
    query: array of representations
    key: array of representations
    value: array of representations
    mask: attention-mask, gates attention
    dropout: float: dropout rate
    mode: 'eval' or 'train': whether to use dropout
    rng: JAX PRNGKey: subkey for disposable use

  Returns:
    Self attention for q, k, v arrays.
  """
  depth = np.shape(query)[-1]
  dots = np.matmul(query, np.swapaxes(key, -1, -2)) / np.sqrt(depth)
  if mask is not None:
    dots = np.where(mask, dots, -1e9)
  # Softmax.
  dots = np.exp(dots - backend.logsumexp(dots, axis=-1, keepdims=True))
  if dropout is not None and dropout >= 1.0:
    raise ValueError('Dropout rates must be lower than 1.')
  if dropout is not None and dropout > 0.0 and mode == 'train':
    keep = backend.random.bernoulli(rng, 1.0 - dropout, dots.shape)
    dots = np.where(keep, dots / (1.0 - dropout), 0)
  out = np.matmul(dots, value)
  return out
Example #9
def BatchNorm(x, params, axis=(0, 1, 2), epsilon=1e-5,
              center=True, scale=True, **unused_kwargs):
  """Layer construction function for a batch normalization layer."""
  mean = np.mean(x, axis, keepdims=True)
  # Fast but less numerically-stable variance calculation than np.var.
  m1 = np.mean(x**2, axis, keepdims=True)
  var = m1 - mean**2
  # x mustn't be onp.ndarray here; otherwise `x-mean` will call mean.__rsub__
  # with each element of x, resulting in an onp.ndarray with dtype `object`.
  z = (x - mean) / np.sqrt(var + epsilon).astype(x.dtype)

  # Expand the parameters to have the right axes.
  beta, gamma = params
  # TODO(phawkins): np.expand_dims should accept an axis tuple.
  # (https://github.com/numpy/numpy/issues/12290)
  ed = tuple(None if i in axis else slice(None) for i in range(np.ndim(x)))
  beta = beta[ed]
  gamma = gamma[ed]

  # Return the z rescaled by the parameters if requested.
  if center and scale:
    ret = gamma * z + beta
  elif center:
    ret = z + beta
  elif scale:
    ret = gamma * z
  else:
    ret = z
  assert ret.dtype == x.dtype, ('The dtype of the output (%s) of batch norm is '
                                'not the same as the input (%s). Batch norm '
                                'should not change the dtype' %
                                (ret.dtype, x.dtype))
  return ret
Example #10
        def forward_slice(query_slice, q_loop_idx, key, value):  # pylint: disable=invalid-name
            """Forward pass for a subset of the query vectors."""
            dots = np.matmul(
                query_slice, np.swapaxes(key, -1, -2)) / np.sqrt(depth)

            # Causal masking
            mask = make_mask(dots.shape[-2], dots.shape[-1], q_loop_idx)
            dots = dots - 1e9 * mask

            # Softmax.
            dots = np.exp(dots -
                          backend.logsumexp(dots, axis=-1, keepdims=True))

            if self.dropout is not None and self.dropout > 0.0:
                # Dropout is broadcast across the batch+head dimension
                dropout_shape = (1, dots.shape[-2], dots.shape[-1])
                slice_rng = jax.random.fold_in(rng, q_loop_idx)
                keep_prob = jax.lax.tie_in(dots, 1.0 - self.dropout)
                keep = backend.random.bernoulli(slice_rng, keep_prob,
                                                dropout_shape)
                multiplier = keep.astype(dots.dtype) / jax.lax.tie_in(
                    keep, keep_prob)
                dots = dots * multiplier

            out_slice = np.matmul(dots, value)
            return out_slice
Example #11
    def update(self, step, grads, params, slots, opt_params):
        updates = []
        (learning_rate, beta1, decay_rate, clipping_threshold,
         weight_decay_rate, epsilon1, epsilon2) = opt_params
        decay_rate = self._decay_rate_pow(step, exponent=decay_rate)
        update_scale = learning_rate
        if self._multiply_by_parameter_scale:
            update_scale *= np.maximum(np.sqrt(np.mean(params * params)),
                                       epsilon2)
        mixing_rate = 1.0 - decay_rate

        grads_sqr = grads * grads + epsilon1
        if self._factored and len(params.shape) >= 2:
            v_row = slots.pop(0)
            v_col = slots.pop(0)
            new_v_row = decay_rate * v_row + mixing_rate * np.mean(grads_sqr,
                                                                   axis=-1)
            new_v_col = decay_rate * v_col + mixing_rate * np.mean(grads_sqr,
                                                                   axis=-2)
            updates.extend([new_v_row, new_v_col])
            row_col_mean = np.mean(new_v_row, axis=-1, keepdims=True)
            row_factor = (new_v_row / row_col_mean)**-0.5
            col_factor = (new_v_col)**-0.5
            y = (grads * np.expand_dims(row_factor, axis=-1) *
                 np.expand_dims(col_factor, axis=-2))
        else:
            v = slots.pop(0)
            new_v = decay_rate * v + mixing_rate * grads_sqr
            updates.append(new_v)
            y = grads * (new_v)**-0.5

        if self._do_clipping:
            clipping_denom = (np.maximum(
                1.0,
                np.sqrt(np.mean(y * y)) / clipping_threshold))
            y /= clipping_denom

        subtrahend = update_scale * y
        if self._do_momentum:
            m = slots.pop(0)
            new_m = beta1 * m + (1.0 - beta1) * subtrahend
            subtrahend = new_m
            updates.append(new_m)

        new_params = (1 - weight_decay_rate) * params - subtrahend
        return new_params, updates
Example #12
 def _update_diagonal(self, step, g, x, m, v):
     v[0] += g * g
     preconditioner = np.where(v[0] > 0, 1.0 / np.sqrt(v[0]),
                               np.zeros_like(v[0]))
     preconditioned_g = preconditioner * g
     m = (1 - self._momentum) * preconditioned_g + self._momentum * m
     x = x - self.step_size(step) * m
     return x, (m, v)
Example #13
 def _update_diagonal(self, grads, params, m, v, opt_params):
     (learning_rate, momentum) = opt_params
     v[0] += grads * grads
     preconditioner = np.where(v[0] > 0, 1.0 / np.sqrt(v[0]),
                               np.zeros_like(v[0]))
     preconditioned_grads = preconditioner * grads
     m = (1 - momentum) * preconditioned_grads + momentum * m
     params = params - (learning_rate * m).astype(params.dtype)
     return params, (m, v)
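
A hedged sketch of the diagonal preconditioning step above in plain NumPy (momentum omitted; the 0.1 learning rate and the toy gradients are illustrative):

import numpy as onp

accum = onp.zeros(3)                       # v[0]: sum of squared gradients
params = onp.ones(3)
for grads in [onp.array([1.0, 0.1, 0.0]), onp.array([2.0, 0.2, 0.0])]:
    accum += grads * grads
    with onp.errstate(divide='ignore'):    # guard the 1/sqrt(0) case
        precond = onp.where(accum > 0, 1.0 / onp.sqrt(accum), 0.0)
    params = params - 0.1 * precond * grads
# Coordinates with large accumulated gradients take smaller steps, and
# coordinates that never received a gradient are left untouched.
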
Example #14
 def update(self, i, g, x, state):
     m, v = state
     b1, b2, eps = self._b1, self._b2, self._eps
     m = (1 - b1) * g + b1 * m  # First  moment estimate.
     v = (1 - b2) * (g**2) + b2 * v  # Second moment estimate.
     mhat = m / (1 - b1**(i + 1))  # Bias correction.
     vhat = v / (1 - b2**(i + 1))
     x = x - self.step_size(i) * mhat / (np.sqrt(vhat) + eps)
     return x, (m, v)
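
A short illustration of why the 1 - b**(i + 1) bias correction is there: with zero-initialized moments, the exponential moving average of a constant gradient is low by exactly that factor (plain Python, illustrative values):

b1 = 0.9
g = 1.0          # pretend the gradient is a constant 1.0
m = 0.0          # zero-initialized first moment
for i in range(5):
    m = (1 - b1) * g + b1 * m
    mhat = m / (1 - b1**(i + 1))
    # m starts badly biased toward 0; mhat recovers the true value 1.0.
    print(i, round(m, 4), round(mhat, 4))
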
Example #15
 def update(self, step, grads, params, avg_sq_grad, opt_params):
     del step
     learning_rate = opt_params["learning_rate"]
     gamma = opt_params["gamma"]
     eps = opt_params["eps"]
     avg_sq_grad = avg_sq_grad * gamma + grads**2 * (1. - gamma)
     params = params - (learning_rate * grads /
                        (np.sqrt(avg_sq_grad) + eps)).astype(params.dtype)
     return params, avg_sq_grad
Example #16
 def update(self, step, grads, params, slots, opt_params):
     m, v = slots
     learning_rate, weight_decay_rate, b1, b2, eps = opt_params
     m = (1 - b1) * grads + b1 * m  # First  moment estimate.
     v = (1 - b2) * (grads**2) + b2 * v  # Second moment estimate.
     mhat = m / (1 - b1**(step + 1))  # Bias correction.
     vhat = v / (1 - b2**(step + 1))
     params = (1 - weight_decay_rate) * params - (
         learning_rate * mhat / (np.sqrt(vhat) + eps)).astype(params.dtype)
     return params, (m, v)
Example #17
  def update(self, i, g, x, state):
    updates = []
    decay_rate = self._decay_rate(i)
    update_scale = self._step_size(i)
    if self._multiply_by_parameter_scale:
      update_scale *= np.maximum(np.sqrt(np.mean(x * x)), self._epsilon2)
    mixing_rate = 1.0 - decay_rate

    g_sqr = g * g + self._epsilon1
    if self._factored and len(x.shape) >= 2:
      v_row = state.pop(0)
      v_col = state.pop(0)
      new_v_row = decay_rate * v_row + mixing_rate * np.mean(g_sqr, axis=-1)
      new_v_col = decay_rate * v_col + mixing_rate * np.mean(g_sqr, axis=-2)
      updates.extend([new_v_row, new_v_col])
      row_col_mean = np.mean(new_v_row, axis=-1, keepdims=True)
      row_factor = (new_v_row / row_col_mean)**-0.5
      col_factor = (new_v_col)**-0.5
      y = (
          g * np.expand_dims(row_factor, axis=-1) *
          np.expand_dims(col_factor, axis=-2))
    else:
      v = state.pop(0)
      new_v = decay_rate * v + mixing_rate * g_sqr
      updates.append(new_v)
      y = g * (new_v)**-0.5

    if self._clipping_threshold is not None:
      clipping_denom = (
          np.maximum(1.0,
                     np.sqrt(np.mean(y * y)) / self._clipping_threshold))
      y /= clipping_denom

    subtrahend = update_scale * y
    if self._beta1:
      m = state.pop(0)
      new_m = self._beta1 * m + (1.0 - self._beta1) * subtrahend
      subtrahend = new_m
      updates.append(new_m)

    new_x = x - subtrahend
    return new_x, updates
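
For intuition on the factored branch: instead of storing a full matrix of second-moment estimates, the update keeps only its row and column means and reconstructs the matrix (up to scale) as an outer product. A hedged NumPy illustration with made-up arrays:

import numpy as onp

g_sqr = onp.random.RandomState(0).rand(4, 6)     # stand-in for grads**2
v_row = g_sqr.mean(axis=-1)                      # shape (4,)
v_col = g_sqr.mean(axis=-2)                      # shape (6,)
approx = onp.outer(v_row, v_col) / v_row.mean()  # factored reconstruction

# The reconstruction is exact whenever the matrix is rank one.
rank1 = onp.outer(onp.array([1.0, 2.0, 3.0, 4.0]),
                  onp.array([1.0, 0.5, 2.0, 1.0, 3.0, 0.25]))
v_row1, v_col1 = rank1.mean(axis=-1), rank1.mean(axis=-2)
assert onp.allclose(onp.outer(v_row1, v_col1) / v_row1.mean(), rank1)
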
Example #18
    def call(self, x, params, state, **unused_kwargs):
        """Layer construction function for a batch normalization layer."""

        running_mean, running_var, num_batches = state

        if self._mode == 'train':
            mean = np.mean(x, self._axis, keepdims=True)
            # Fast but less numerically-stable variance calculation than np.var.
            m1 = np.mean(x**2, self._axis, keepdims=True)
            var = m1 - mean**2
            num_batches = num_batches + 1
            if self._momentum is None:
                # A simple average over all batches seen so far
                exponential_average_factor = 1.0 / num_batches
            else:
                exponential_average_factor = self._momentum

            def average(factor, new, old):
                return (factor * new + (1 - factor) * old).astype(old.dtype)

            running_mean = average(exponential_average_factor, mean,
                                   running_mean)
            running_var = average(exponential_average_factor, var, running_var)
            state = (running_mean, running_var, num_batches)
        else:
            mean = running_mean
            var = running_var

        z = (x - mean.astype(x.dtype)) / np.sqrt(var + self._epsilon).astype(
            x.dtype)

        # Expand the parameters to have the right axes.
        beta, gamma = params
        # TODO(phawkins): np.expand_dims should accept an axis tuple.
        # (https://github.com/numpy/numpy/issues/12290)
        ed = tuple(None if i in self._axis else slice(None)
                   for i in range(np.ndim(x)))
        beta = beta[ed]
        gamma = gamma[ed]

        # Return the z rescaled by the parameters if requested.
        if self._center and self._scale:
            output = gamma * z + beta
        elif self._center:
            output = z + beta
        elif self._scale:
            output = gamma * z
        else:
            output = z
        assert output.dtype == x.dtype, (
            'The dtype of the output (%s) of batch '
            'norm is not the same as the input (%s). '
            'Batch norm should not change the dtype' % (output.dtype, x.dtype))
        return output, state
Example #19
 def learning_rate(step):  # pylint: disable=invalid-name
   """Step to learning rate function."""
   ret = 1.0
   for name in factors:
     if name == "constant":
       ret *= constant
     elif name == "linear_warmup":
       ret *= np.minimum(1.0, step / warmup_steps)
     elif name == "rsqrt_decay":
       ret /= np.sqrt(np.maximum(step, warmup_steps))
     else:
       raise ValueError("Unknown factor %s." % name)
   return ret
Example #20
 def apply_fun(params, x, **kwargs):
   beta, gamma = params
   # TODO(phawkins): np.expand_dims should accept an axis tuple.
   # (https://github.com/numpy/numpy/issues/12290)
   ed = tuple(None if i in axis else slice(None) for i in range(np.ndim(x)))
   beta = beta[ed]
   gamma = gamma[ed]
   mean, var = np.mean(x, axis, keepdims=True), fastvar(x, axis, keepdims=True)
   z = (x - mean) / np.sqrt(var + epsilon)
   if center and scale: return gamma * z + beta
   if center: return z + beta
   if scale: return gamma * z
   return z
Example #21
 def test_batch_norm(self):
     input_shape = (2, 3, 4)
     input_dtype = np.float32
     eps = 1e-5
     rng = backend.random.get_prng(0)
     inp1 = np.reshape(np.arange(np.prod(input_shape), dtype=input_dtype),
                       input_shape)
     m1 = 11.5
     v1 = 47.9167
     layer = normalization.BatchNorm(axis=(0, 1, 2))
     params, state = layer.initialize(input_shape, input_dtype, rng)
     onp.testing.assert_allclose(state[0], 0)
     onp.testing.assert_allclose(state[1], 0)
     self.assertEqual(state[2], 0)
     out, state = layer(inp1, params, state)
     onp.testing.assert_allclose(state[0], m1)
     onp.testing.assert_allclose(state[1], v1, rtol=1e-6)
     self.assertEqual(state[2], 1)
     onp.testing.assert_allclose(out, (inp1 - m1) / np.sqrt(v1 + eps),
                                 rtol=1e-6)
     inp2 = inp1 * 2 + 3
     m2 = m1 * 2 + 3
     v2 = v1 * 4
     m12 = (m1 + m2) / 2
     v12 = (v1 + v2) / 2
     out, state = layer(inp2, params, state)
     onp.testing.assert_allclose(state[0], m12)
     onp.testing.assert_allclose(state[1], v12, rtol=1e-6)
     self.assertEqual(state[2], 2)
     onp.testing.assert_allclose(out, (inp2 - m2) / np.sqrt(v2 + eps),
                                 rtol=1e-6)
     layer = normalization.BatchNorm(axis=(0, 1, 2), mode="eval")
     inp3 = inp1 * 5 + 7
     out, state_unchanged = layer(inp3, params, state)
     for i in range(3):
         onp.testing.assert_allclose(state_unchanged[i], state[i])
     onp.testing.assert_allclose(out, (inp3 - m12) / np.sqrt(v12 + eps),
                                 rtol=1e-6)
Example #22
        def binned_attn(sqk, sv):  # pylint: disable=invalid-name
            """Performs attention on sorted queries/keys/values."""
            # Split off a "bin" axis so that attention only occurs within chunks.
            bq_t = bkv_t = chunk_scalars(sjoint_t)
            bqk = chunk_vectors(sqk)
            bv = chunk_vectors(sv)

            # Hashing operates on unit-length vectors. Unnormalized query vectors are
            # fine because they effectively provide a learnable temperature for the
            # attention softmax, but normalizing keys is needed so that similarity for
            # the purposes of attention correctly corresponds to hash locality.
            bq = bqk
            bk = self.make_unit_length(bqk)

            # Allow each chunk to attend within itself, and also one chunk back. Chunk
            # boundaries might occur in the middle of a sequence of items from the
            # same bin, so this increases the chances of attending to relevant items.
            # TODO(kitaev): benchmark whether XLA pad operation is noticeably faster.
            bk_extra = np.concatenate([bk[:, -1:, :, :], bk[:, :-1, :, :]],
                                      axis=1)
            bk = np.concatenate([bk, bk_extra], axis=2)
            bv_extra = np.concatenate([bv[:, -1:, :, :], bv[:, :-1, :, :]],
                                      axis=1)
            bv = np.concatenate([bv, bv_extra], axis=2)
            bkv_t_extra = np.concatenate([bkv_t[:, -1:, :], bkv_t[:, :-1, :]],
                                         axis=1)
            bkv_t = np.concatenate([bkv_t, bkv_t_extra], axis=2)

            # Dot-product attention.
            dots = np.matmul(bq, np.swapaxes(bk, -1, -2)) / np.sqrt(
                bq.shape[-1])

            # Causal masking
            mask = jax.lax.convert_element_type(
                jax.lax.lt(bq_t[:, :, :, None], bkv_t[:, :, None, :]),
                np.float32)
            dots = dots - 1e9 * mask

            # Mask out attention to self except when no other targets are available.
            self_mask = jax.lax.broadcasted_eye(dots.dtype, dots.shape, (2, 3))
            self_mask = jax.lax.tie_in(dots, self_mask)
            dots = dots - 32 * self_mask

            # Softmax.
            dots = np.exp(dots -
                          backend.logsumexp(dots, axis=-1, keepdims=True))
            bo = np.matmul(dots, bv)

            so = unchunk_vectors(bo)
            return so
Example #23
        def forward_slice(query_slice, q_loop_idx, key, value):
            """Forward pass for a subset of the query vectors."""
            dots = np.matmul(
                query_slice, np.swapaxes(key, -1, -2)) / np.sqrt(depth)

            # Causal masking
            mask = make_mask(dots.shape[-2], dots.shape[-1], q_loop_idx)
            dots = dots - 1e9 * mask

            # Softmax.
            dots = np.exp(dots - dots.max(axis=-1, keepdims=True))
            dots = dots / dots.sum(axis=-1, keepdims=True)
            out_slice = np.matmul(dots, value)
            return out_slice
Example #24
 def learning_rate(step):  # pylint: disable=invalid-name
     """Step to learning rate function."""
     ret = 1.0
     for name in factors:
         if name == "constant":
             ret *= constant
         elif name == "linear_warmup":
             ret *= np.minimum(1.0, step / warmup_steps)
         elif name == "rsqrt_decay":
             ret /= np.sqrt(np.maximum(step, warmup_steps))
         elif name == "decay_every":
             ret *= (decay_factor**(step // steps_per_decay))
         else:
             raise ValueError("Unknown factor %s." % name)
     ret = np.asarray(ret, dtype=np.float32)
     return {"learning_rate": ret}
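
A self-contained sketch of how these factors compose for the common "constant * linear_warmup * rsqrt_decay" schedule (the constant of 0.1 and warmup_steps of 400 are illustrative values, not library defaults):

import numpy as onp

def lr(step, constant=0.1, warmup_steps=400):
    ret = constant
    ret *= onp.minimum(1.0, step / warmup_steps)      # linear_warmup
    ret /= onp.sqrt(onp.maximum(step, warmup_steps))  # rsqrt_decay
    return ret

for step in [1, 100, 400, 1600, 6400]:
    # Ramps up linearly until step 400, then decays as 1/sqrt(step).
    print(step, float(lr(step)))
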
Example #25
 def test_batch_norm(self):
     input_shape = (2, 3, 4)
     input_dtype = np.float32
     eps = 1e-5
     rng = backend.random.get_prng(0)
     inp1 = np.reshape(np.arange(np.prod(input_shape), dtype=input_dtype),
                       input_shape)
      m1 = 11.5  # Mean of this input.
      v1 = 47.9167  # Variance of this input.
     layer = normalization.BatchNorm(axis=(0, 1, 2))
     params, state = layer.initialize(input_shape, input_dtype, rng)
     onp.testing.assert_allclose(state[0], 0)
     onp.testing.assert_allclose(state[1], 1)
     self.assertEqual(state[2], 0)
     out, state = layer(inp1, params, state)
     onp.testing.assert_allclose(state[0], m1 * 0.001)
     onp.testing.assert_allclose(state[1], 0.999 + v1 * 0.001, rtol=1e-6)
     self.assertEqual(state[2], 1)
     onp.testing.assert_allclose(out, (inp1 - m1) / np.sqrt(v1 + eps),
                                 rtol=1e-6)
Example #26
 def _update_sketched(self, grads, params, m, v, opt_params):
   """Update for higher-rank parameters."""
   (learning_rate, momentum) = opt_params
   shape = params.shape
   rank = len(shape)
   reshaped_accumulators = [np.reshape(v[i], self._expanded_shape(shape, i))
                            for i in range(rank)]
   current_accumulator = self._minimum(reshaped_accumulators)
   current_accumulator += grads * grads
   accumulator_inv_sqrt = np.where(current_accumulator > 0.0,
                                   1.0 / np.sqrt(current_accumulator),
                                   np.zeros_like(current_accumulator))
   preconditioned_gradient = grads * accumulator_inv_sqrt
   m = (1.0 - momentum) * preconditioned_gradient + momentum * m
   params = params - (learning_rate * m).astype(params.dtype)
   for i in range(len(v)):
     axes = list(range(int(i))) + list(range(int(i) + 1, rank))
     dim_accumulator = np.amax(current_accumulator, axis=axes)
     v[i] = dim_accumulator
   return params, (m, v)
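
The sketched update above generalizes the following 2-D picture: keep one accumulator per row and one per column, rebuild a per-entry accumulator as their elementwise minimum, and fold the result back in via per-dimension maxima. A hedged NumPy illustration with made-up shapes:

import numpy as onp

g = onp.random.RandomState(0).rand(3, 4)    # stand-in gradient, 3x4 parameter
v_rows = onp.zeros(3)                       # one accumulator per row
v_cols = onp.zeros(4)                       # one accumulator per column

# Per-entry accumulator: elementwise minimum of the broadcast sketches,
# plus the new squared gradient.
current = onp.minimum(v_rows[:, None], v_cols[None, :]) + g * g

# Fold the update back into the per-dimension sketches via maxima.
v_rows = current.max(axis=1)
v_cols = current.max(axis=0)
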
Example #27
    def forward_slice(query_slice, q_loop_idx, key, value):  # pylint: disable=invalid-name
      """Forward pass for a subset of the query vectors."""
      if self._share_qk:
        key = self.make_unit_length(key)

      dots = np.matmul(
          query_slice, np.swapaxes(key, -1, -2)) / np.sqrt(depth)

      # Causal masking
      mask = make_mask(dots.shape[-2], dots.shape[-1], q_loop_idx)
      dots = dots - 1e9 * mask

      # Mask out attention to self except when no other targets are available.
      if self._share_qk:
        self_mask = make_self_mask(dots.shape[-2], dots.shape[-1], q_loop_idx)
        dots = dots - 1e5 * self_mask

      # Softmax.
      dots = np.exp(dots - backend.logsumexp(dots, axis=-1, keepdims=True))

      if self.dropout is not None and self.dropout > 0.0:
        # Dropout is broadcast across the batch+head dimension
        dropout_shape = (1, dots.shape[-2], dots.shape[-1])
        slice_rng = jax.random.fold_in(rng, q_loop_idx)
        keep_prob = jax.lax.tie_in(dots, 1.0 - self.dropout)
        keep = backend.random.bernoulli(slice_rng, keep_prob, dropout_shape)
        multiplier = keep.astype(dots.dtype) / jax.lax.tie_in(keep, keep_prob)
        dots = dots * multiplier

      if self._hard_k > 0:
        top_k = np.sort(dots)[..., -self._hard_k]  # Get the top-kth weight.
        top_k = jax.lax.stop_gradient(top_k)
        dots -= top_k[..., np.newaxis]  # Weights below top-k become <= 0.
        dots = np.maximum(dots, 0)
        dots_sum = np.sum(dots, axis=-1, keepdims=True)  # Re-normalize.
        dots /= dots_sum  # Re-normalize.

      out_slice = np.matmul(dots, value)
      return out_slice
Example #28
 def _update_sketched(self, step, g, x, m, v):
     """Update for higher-rank parameters."""
     shape = x.shape
     rank = len(shape)
     reshaped_accumulators = [
         np.reshape(v[i], self._expanded_shape(shape, i))
         for i in range(rank)
     ]
     current_accumulator = self._minimum(reshaped_accumulators)
     current_accumulator += g * g
     accumulator_inv_sqrt = np.where(current_accumulator > 0.0,
                                     1.0 / np.sqrt(current_accumulator),
                                     np.zeros_like(current_accumulator))
     preconditioned_gradient = g * accumulator_inv_sqrt
     m = (1.0 -
          self._momentum) * preconditioned_gradient + self._momentum * m
     x = x - self.step_size(step) * m
     for i in range(len(v)):
         axes = list(range(int(i))) + list(range(int(i) + 1, rank))
         dim_accumulator = np.amax(current_accumulator, axis=axes)
         v[i] = dim_accumulator
     return x, (m, v)
Example #29
    def call(self, inputs, params=(), state=(), rng=None, **kwargs):
        del params, kwargs
        # We use the same vector as both a query and a key. For now we haven't
        # adjusted any of the surrounding code, so we still get a separate "key"
        # input that we ignore.
        qk, _, v = inputs
        seqlen = qk.shape[-2]

        # qk/v are n_hashes*n_batch*n_heads, seqlen, d_head
        # TODO(kitaev): is it faster to fuse this tiling into gather/scatter ops?
        qk = np.tile(qk, (self.n_hashes, 1, 1))
        v = np.tile(v, (self.n_hashes, 1, 1))

        # bins are n_hashes*n_batch*n_heads, seqlen
        # They specify which hash bucket the query/key/value vectors fall in.
        bins = self.hash_vectors(qk, rng=rng)

        # joint_t is n_hashes*n_batch*n_heads, seqlen
        joint_t = jax.lax.tie_in(qk, np.arange(seqlen))
        joint_t = np.reshape(joint_t, (1, seqlen))
        joint_t = np.broadcast_to(joint_t, qk.shape[:-1])

        assert int(
            (self.n_buckets_per_bin * self.n_bins + 1) * seqlen
        ) < 2**31, (
            'Potential 32-bit integer overflow; please double-check the code.')
        joint_bins_and_t = seqlen * bins + joint_t

        def chunk_scalars(x):  # pylint: disable=invalid-name
            return np.reshape(x, (x.shape[0], self.n_bins, -1))

        def chunk_vectors(x):  # pylint: disable=invalid-name
            return np.reshape(x, (x.shape[0], self.n_bins, -1, x.shape[-1]))

        def unchunk_vectors(x):  # pylint: disable=invalid-name
            return np.reshape(x, (x.shape[0], -1, x.shape[-1]))

        # Sort everything by bin number, with a secondary sort by time
        # (variables starting with "s" are sorted)
        _, sjoint_t = jax.lax.sort_key_val(joint_bins_and_t,
                                           joint_t,
                                           dimension=-1)
        _, undo_sort = jax.lax.sort_key_val(sjoint_t, joint_t, dimension=-1)
        # TODO(kitaev): why does jax flag integer indices as differentiable?
        # If we don't call stop_gradient here, custom gradients below won't work
        # because the primitive functions close over "differentiable" variables.
        sjoint_t = jax.lax.stop_gradient(sjoint_t)
        undo_sort = jax.lax.stop_gradient(undo_sort)

        # The backward pass of gather is in general a scatter operation, but we know
        # we're dealing with permutations so we use gather for the backward pass
        # too. This custom gradient should be about 2x faster than having jax infer
        # one that uses scatter ops instead.
        def permute_impl(vecs):
            assert len(vecs.shape) == 3
            return np.take_along_axis(vecs, sjoint_t[:, :, None], axis=-2)

        def unpermute_impl(vecs):
            assert len(vecs.shape) == 3
            return np.take_along_axis(vecs, undo_sort[:, :, None], axis=-2)

        @jax.custom_transforms
        def permute(vecs):
            return permute_impl(vecs)

        def permute_vjp(vecs):
            out_vecs = permute_impl(vecs)

            def vjpfun(grad):
                return (unpermute_impl(grad), )

            return out_vecs, vjpfun

        @jax.custom_transforms
        def unpermute(vecs):
            return unpermute_impl(vecs)

        def unpermute_vjp(vecs):
            out_vecs = unpermute_impl(vecs)

            def vjpfun(grad):
                return (permute_impl(grad), )

            return out_vecs, vjpfun

        jax.defvjp_all(permute, permute_vjp)
        jax.defvjp_all(unpermute, unpermute_vjp)

        sqk = permute(qk)
        sv = permute(v)

        # Split off a "bin" axis so that attention only occurs within chunks.
        bq_t = bkv_t = chunk_scalars(sjoint_t)
        bqk = chunk_vectors(sqk)
        bv = chunk_vectors(sv)

        # Hashing operates on unit-length vectors. Unnormalized query vectors are
        # fine because they effectively provide a learnable temperature for the
        # attention softmax, but normalizing keys is needed so that similarity for
        # the purposes of attention correctly corresponds to hash locality.
        bq = bqk
        bk = self.make_unit_length(bqk)

        # Allow each chunk to attend within itself, and also one chunk back. Chunk
        # boundaries might occur in the middle of a sequence of items from the
        # same bin, so this increases the chances of attending to relevant items.
        # TODO(kitaev): benchmark whether XLA pad operation is noticeably faster.
        bk_extra = np.concatenate([bk[:, -1:, :, :], bk[:, :-1, :, :]], axis=1)
        bk = np.concatenate([bk, bk_extra], axis=2)
        bv_extra = np.concatenate([bv[:, -1:, :, :], bv[:, :-1, :, :]], axis=1)
        bv = np.concatenate([bv, bv_extra], axis=2)
        bkv_t_extra = np.concatenate([bkv_t[:, -1:, :], bkv_t[:, :-1, :]],
                                     axis=1)
        bkv_t = np.concatenate([bkv_t, bkv_t_extra], axis=2)

        # Dot-product attention.
        dots = np.matmul(bq, np.swapaxes(bk, -1, -2)) / np.sqrt(bq.shape[-1])

        # Causal masking
        mask = jax.lax.convert_element_type(
            jax.lax.lt(bq_t[:, :, :, None], bkv_t[:, :, None, :]), np.float32)
        dots = dots - 1e9 * mask

        # Mask out attention to self except when no other targets are available.
        self_mask = jax.lax.broadcasted_eye(dots.dtype, dots.shape, (2, 3))
        self_mask = jax.lax.tie_in(dots, self_mask)
        dots = dots - 32 * self_mask

        # Softmax.
        dots_logsumexp = backend.logsumexp(dots, axis=-1, keepdims=True)
        dots = np.exp(dots - dots_logsumexp)

        if self._hard_k > 0:
            # Get the top-kth weight.
            top_k = np.sort(dots)[..., -self._hard_k]
            top_k = jax.lax.stop_gradient(top_k)
            dots -= top_k[..., np.newaxis]  # Weights below top-k become <= 0.
            dots = np.maximum(dots, 0)
            dots_sum = np.sum(dots, axis=-1,
                              keepdims=True)  # Sum to re-normalize.
            dots_logsumexp += np.log(dots_sum)  # Add it to the weight.
            dots /= dots_sum  # Re-normalize.

        bo = np.matmul(dots, bv)
        so = unchunk_vectors(bo)
        slogits = unchunk_vectors(dots_logsumexp)

        o = unpermute(so)
        logits = unpermute(slogits)

        o = np.reshape(o, (self.n_hashes, -1, seqlen, o.shape[-1]))
        logits = np.reshape(logits, (self.n_hashes, -1, seqlen, 1))
        probs = np.exp(logits -
                       backend.logsumexp(logits, axis=0, keepdims=True))
        out = np.sum(o * probs, axis=0)
        assert out.shape == inputs[2].shape

        return out, state
Example #30
 def make_unit_length(self, x, epsilon=1e-6):
     variance = np.mean(x**2, axis=-1, keepdims=True)
     norm_inputs = x / np.sqrt(variance + epsilon)
     return norm_inputs
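
A standalone restatement in plain NumPy. Note that dividing by the root of the mean square gives vectors with unit mean square, so their L2 norm is sqrt(d) rather than exactly 1 (the example values are illustrative):

import numpy as onp

def make_unit_length(x, epsilon=1e-6):
    mean_sq = onp.mean(x**2, axis=-1, keepdims=True)  # second moment, not variance
    return x / onp.sqrt(mean_sq + epsilon)

y = make_unit_length(onp.array([[3.0, 4.0]]))
print(onp.linalg.norm(y, axis=-1))  # ~1.414 == sqrt(2) for d == 2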