Example #1
 def init(self, x):
   shape = x.shape
   state = []
   if self._factored and len(shape) >= 2:
     v_row = np.zeros(shape[:-1], dtype=np.float32)
     v_col = np.zeros(shape[:-2] + shape[-1:], dtype=np.float32)
     state.extend([v_row, v_col])
   else:
     v = np.zeros_like(x)
     state.append(v)
   if self._beta1:
     m = np.zeros_like(x)
     state.append(m)
   return state
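As a quick illustration (not from the source), the factored branch above keeps one accumulator per row and one per column instead of a full-size second-moment tensor; a minimal NumPy sketch with a hypothetical (4, 3) weight:
import numpy as np

# Hypothetical (4, 3) weight: the factored slots are (4,) and (3,) vectors,
# while the non-factored branch would keep a full (4, 3) accumulator.
x = np.ones((4, 3), dtype=np.float32)
v_row = np.zeros(x.shape[:-1], dtype=np.float32)                  # shape (4,)
v_col = np.zeros(x.shape[:-2] + x.shape[-1:], dtype=np.float32)   # shape (3,)
print(v_row.shape, v_col.shape, np.zeros_like(x).shape)  # (4,) (3,) (4, 3)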
Example #2
 def init(self, params):
     shape = params.shape
     slots = []
     if self._factored and len(shape) >= 2:
         v_row = np.zeros(shape[:-1], dtype=np.float32)
         v_col = np.zeros(shape[:-2] + shape[-1:], dtype=np.float32)
         slots.extend([v_row, v_col])
     else:
         v = np.zeros_like(params)
         slots.append(v)
     if self._do_momentum:
         m = np.zeros_like(params)
         slots.append(m)
     return slots
Example #3
def LayerNormParams(input_shape, input_dtype, rng, epsilon=1e-6):
    """Helper: create layer norm parameters."""
    del input_dtype, rng, epsilon
    features = input_shape[-1]
    scale = np.ones(features)
    bias = np.zeros(features)
    return (scale, bias)
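A minimal usage sketch (illustrative; plain NumPy, hypothetical (2, 8) input) of how the returned scale and bias are typically applied in a layer-norm forward pass:
import numpy as np

# Parameters for the last axis of a (batch=2, features=8) input.
scale, bias = LayerNormParams((2, 8), np.float32, rng=None)
x = np.random.randn(2, 8).astype(np.float32)
normed = (x - x.mean(axis=-1, keepdims=True)) / np.sqrt(
    x.var(axis=-1, keepdims=True) + 1e-6)
out = normed * scale + bias  # scale/bias broadcast over the batch axis
print(scale.shape, bias.shape, out.shape)  # (8,) (8,) (2, 8)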
Example #4
def EncoderDecoderMask(x, **unused_kwargs):
    """Makes encoder-decoder mask from decoder input and a padding mask."""
    decoder_input, padding_mask = x
    padding_mask = np.reshape(
        padding_mask, (padding_mask.shape[0], 1, 1, padding_mask.shape[-1]))
    # Final mask shape is [batch, 1 for heads, decoder-len, encoder-len].
    return padding_mask + np.zeros((1, 1, decoder_input.shape[1], 1))
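A toy shape check (illustrative; plain NumPy standing in for the backend) of the broadcast above:
import numpy as np

# Batch 2, decoder length 3, encoder length 5: the padding mask is broadcast
# across heads and decoder positions.
decoder_input = np.zeros((2, 3, 8))
padding_mask = np.ones((2, 5))
mask = EncoderDecoderMask((decoder_input, padding_mask))
print(mask.shape)  # (2, 1, 3, 5)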
Example #5
 def test_computes_mean_with_weights(self, backend_name):
     with backend.use_backend(backend_name):
         inputs = [np.array([1, 2, 3])]
         targets = [np.zeros(3)]
         weights = [np.array([3, 1, 0])]
         mean = trax.masked_mean(inputs, targets, weights)
         onp.testing.assert_allclose(mean, 1.25)
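The expected value is consistent with a weighted mean of the inputs: (1·3 + 2·1 + 3·0) / (3 + 1 + 0) = 5/4 = 1.25, with the zero weight masking out the last element.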
Example #6
def _layer_norm_params(input_shape, input_dtype, rng):
    """Helper: create layer norm parameters."""
    del input_dtype, rng
    features = input_shape[-1]
    scale = np.ones(features)
    bias = np.zeros(features)
    return (scale, bias)
Example #7
def _layer_norm_new_params(input_shape, rng, epsilon=1e-6):  # pylint: disable=invalid-name
    """Helper: create layer norm parameters."""
    del rng, epsilon
    features = input_shape[-1]
    scale = np.ones(features)
    bias = np.zeros(features)
    return (scale, bias)
Example #8
def NewPositionalEncoding(x, positions=None, **kwargs):
    """Implements new positional encoding."""
    del kwargs
    x_length = np.shape(x)[1]
    pos = np.array(positions)[np.newaxis, :x_length, :]
    pos += np.zeros((np.shape(x)[0], 1, 1))  # Broadcast on batch.
    res = np.concatenate([x, pos], axis=2)
    return res
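An illustrative call (plain NumPy; batch size 1 so the in-place broadcast also works outside the JAX backend) showing that the position channels are concatenated onto the feature axis:
import numpy as np

# Length-3 input with 4 features; positions is a hypothetical (max_len, 2) table.
x = np.zeros((1, 3, 4))
positions = np.arange(10, dtype=np.float32).reshape(5, 2)
out = NewPositionalEncoding(x, positions=positions)
print(out.shape)  # (1, 3, 6): 4 features plus 2 position channels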
Example #9
def _batch_norm_new_params(input_shape, rng, axis=(0, 1, 2),
                           center=True, scale=True, **kwargs):
  """Helper to initialize batch norm params."""
  del rng, kwargs
  axis = (axis,) if np.isscalar(axis) else axis
  shape = tuple(d for i, d in enumerate(input_shape) if i not in axis)
  beta = np.zeros(shape, dtype='float32') if center else ()
  gamma = np.ones(shape, dtype='float32') if scale else ()
  return (beta, gamma)
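For reference, an illustrative call with a hypothetical NHWC input shape; reducing over axes (0, 1, 2) leaves per-channel beta and gamma:
import numpy as np

beta, gamma = _batch_norm_new_params((8, 32, 32, 16), rng=None)
print(beta.shape, gamma.shape)  # (16,) (16,)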
Example #10
    def new_parameters(self, input_shape, input_dtype, rng):
        """Helper to initialize batch norm params."""
        del input_dtype, rng
        axis = self._axis
        axis = (axis, ) if np.isscalar(axis) else axis
        shape = tuple(d for i, d in enumerate(input_shape) if i not in axis)
        beta = np.zeros(shape, dtype='float32') if self._center else ()
        gamma = np.ones(shape, dtype='float32') if self._scale else ()

        def get_stats_axis(i, d):
            if i in axis:
                return 1
            else:
                return d

        stats_shape = tuple(
            get_stats_axis(i, d) for i, d in enumerate(input_shape))
        running_mean = np.zeros(stats_shape, dtype=np.float32)
        running_var = np.ones(stats_shape, dtype=np.float32)
        num_batches = np.zeros((), dtype=np.int32)
        return (beta, gamma), (running_mean, running_var, num_batches)
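A small illustration (not from the source) of the running-statistics shape computed above, again for a hypothetical NHWC input with axis (0, 1, 2):
import numpy as np

# Reduced axes keep size 1 so the running statistics broadcast over the input.
input_shape, axis = (8, 32, 32, 16), (0, 1, 2)
stats_shape = tuple(1 if i in axis else d for i, d in enumerate(input_shape))
running_mean = np.zeros(stats_shape, dtype=np.float32)
running_var = np.ones(stats_shape, dtype=np.float32)
print(stats_shape)  # (1, 1, 1, 16)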
Example #11
def _fast_inference_init_state(input_shapes, input_dtypes, buffer_length):
  """Initializes state of a causal attention layer for fast inference."""
  ((batch_size, _, _), _, _) = input_shapes
  def init_buffer(shape, dtype):
    (_, _, depth) = shape
    return np.zeros((batch_size, buffer_length, depth), dtype=dtype)
  (_, k, v) = tuple(
      init_buffer(shape, dtype)
      for (shape, dtype) in zip(input_shapes, input_dtypes)
  )
  mask = np.zeros((batch_size, 1, buffer_length))
  index = 0
  state = (k, v, mask, index)
  return state
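An illustrative call (plain NumPy, hypothetical shapes) showing the decoding-cache shapes this produces:
import numpy as np

# Queries/keys/values of shape (batch=2, time=1, depth=4), cached into
# length-8 buffers for incremental decoding.
shapes = ((2, 1, 4), (2, 1, 4), (2, 1, 4))
dtypes = (np.float32, np.float32, np.float32)
k, v, mask, index = _fast_inference_init_state(shapes, dtypes, 8)
print(k.shape, v.shape, mask.shape, index)  # (2, 8, 4) (2, 8, 4) (2, 1, 8) 0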
Example #12
 def init(self, x):
     vs = [np.zeros(sz, dtype=x.dtype) for sz in x.shape]
     return (np.zeros_like(x), vs)
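A quick shape sketch (not from the source): this layout keeps one full-shape buffer plus one accumulator per tensor dimension:
import numpy as np

# For a hypothetical (4, 3) parameter: a full-shape buffer plus a vector per axis.
x = np.ones((4, 3), dtype=np.float32)
m, vs = np.zeros_like(x), [np.zeros(sz, dtype=x.dtype) for sz in x.shape]
print(m.shape, [v.shape for v in vs])  # (4, 3) [(4,), (3,)]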
Example #13
    def forward_and_vjp(self, inputs, ct, params=(), **kwargs):
        # This is the core of the memory-efficient attention implementation, where
        # we use the jax.lax.while_loop primitive to compute attention for a small
        # set of query positions at a time. Note how in the backwards pass, we
        # compute both the forward direction (to recover the previous layer's
        # activations) and the backward direction simultaneously. This allows us to
        # only use a single loop, where the inner portion of the loop does a slice
        # of the forward+backward joint computation. Unfortunately we have had to
        # introduce a large number of wrapper classes (including
        # ReversibleAttentionHalfResidual and ApplyAttentionWrapper) for the sole
        # purpose of connecting this implementation of forward_and_vjp with the core
        # backprop implementation.

        query, key, value = inputs
        depth = np.shape(query)[-1]
        do_backprop = ct is not None

        def make_mask(N, M, k):
            x = np.arange(N, dtype=np.int32)
            y = np.arange(M, dtype=np.int32)
            mask = jax.lax.lt((jax.lax.broadcast_in_dim(
                x, shape=(N, M), broadcast_dimensions=(0, )) + k),
                              jax.lax.broadcast(y, [N]))
            mask = jax.lax.convert_element_type(mask, np.float32)
            return mask

        def forward_slice(query_slice, q_loop_idx, key, value):
            """Forward pass for a subset of the query vectors."""
            dots = np.matmul(query_slice, np.swapaxes(key, -1,
                                                      -2)) / np.sqrt(depth)

            # Causal masking
            mask = make_mask(dots.shape[-2], dots.shape[-1], q_loop_idx)
            dots = dots - 1e9 * mask

            # Softmax.
            dots = np.exp(dots - dots.max(axis=-1, keepdims=True))
            dots = dots / dots.sum(axis=-1, keepdims=True)
            out_slice = np.matmul(dots, value)
            return out_slice

        def forward_and_vjp_slice(query_slice, q_loop_idx, key, value,
                                  ct_slice):
            output_slice, vjpfun = jax.vjp(forward_slice, query_slice,
                                           q_loop_idx, key, value)
            return output_slice, vjpfun(ct_slice)

        q_loop_idx = np.zeros((), dtype=np.int32)
        q_loop_max = query.shape[2]
        q_loop_stride = self._loop_stride
        assert q_loop_max % q_loop_stride == 0, (
            'Stride must evenly divide the number of query elements.')

        out_accum = np.zeros_like(query)
        if do_backprop:
            query_ct_accum = np.zeros_like(query)
            key_ct_accum = np.zeros_like(key)
            value_ct_accum = np.zeros_like(value)
            init_vals = (q_loop_idx, out_accum, query_ct_accum, key_ct_accum,
                         value_ct_accum)
        else:
            init_vals = (q_loop_idx, out_accum)

        def cond_fun(vals):
            q_loop_idx = vals[0]
            return jax.lax.lt(q_loop_idx, q_loop_max)

        def body_fun(vals):
            """Compute a slice of the attention mechanism."""
            if do_backprop:
                (q_loop_idx, out_accum, query_ct_accum, key_ct_accum,
                 value_ct_accum) = vals
            else:
                q_loop_idx, out_accum = vals

            query_slice = jax.lax.dynamic_slice_in_dim(query,
                                                       q_loop_idx,
                                                       q_loop_stride,
                                                       axis=2)

            if do_backprop:
                ct_slice = jax.lax.dynamic_slice_in_dim(ct,
                                                        q_loop_idx,
                                                        q_loop_stride,
                                                        axis=2)
                out_slice, partial_ct = forward_and_vjp_slice(
                    query_slice, q_loop_idx, key, value, ct_slice)
                query_ct_accum = jax.lax.dynamic_update_slice_in_dim(
                    query_ct_accum, partial_ct[0], q_loop_idx, axis=2)
                # ignore partial_ct[1], which is wrt the loop idx
                key_ct_accum = key_ct_accum + partial_ct[2]
                value_ct_accum = value_ct_accum + partial_ct[3]
            else:
                out_slice = forward_slice(query_slice, q_loop_idx, key, value)

            out_accum = jax.lax.dynamic_update_slice_in_dim(out_accum,
                                                            out_slice,
                                                            q_loop_idx,
                                                            axis=2)
            q_loop_idx = q_loop_idx + q_loop_stride

            if do_backprop:
                return (q_loop_idx, out_accum, query_ct_accum, key_ct_accum,
                        value_ct_accum)
            else:
                return (q_loop_idx, out_accum)

        final_vals = jax.lax.while_loop(cond_fun, body_fun, init_vals)

        if not do_backprop:
            return final_vals[1], None
        else:
            return final_vals[1], final_vals[2:]
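As an aside (illustrative; plain NumPy in place of jax.lax), the causal mask built by make_mask marks key positions that lie beyond a query's global position k + i:
import numpy as np

# make_mask(N=3, M=5, k=2): entry [i, j] is 1.0 where attention is disallowed,
# i.e. where the key position j lies beyond the query's global position k + i.
N, M, k = 3, 5, 2
mask = ((np.arange(N)[:, None] + k) < np.arange(M)[None, :]).astype(np.float32)
print(mask)
# [[0. 0. 0. 1. 1.]
#  [0. 0. 0. 0. 1.]
#  [0. 0. 0. 0. 0.]]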
Example #14
 def init(self, params):
     vs = [np.zeros(sz, dtype=params.dtype) for sz in params.shape]
     return (np.zeros_like(params), vs)
Example #15
 def init_buffer(shape, dtype):
   (_, _, depth) = shape
   return np.zeros((batch_size, buffer_length, depth), dtype=dtype)
Example #16
    def call_and_grad(self, inputs, ct, rng=None, **kwargs):
        del kwargs
        query, key, value = inputs
        depth = np.shape(query)[-1]
        do_backprop = ct is not None

        # jax uses the term cotangent (ct) to refer to gradient signals, and
        # vector-Jacobian product (vjp) for back-propagation through a layer.

        def make_mask(N, M, k):  # pylint: disable=invalid-name
            """Constructs a slice of the causal attention mask.

      Args:
        N: number of query positions
        M: number of key positions
        k: position of the initial query element

      Returns:
        N x M mask, where 1.0 indicates that attention is not allowed.
      """
            x = np.arange(N, dtype=np.int32)
            y = np.arange(M, dtype=np.int32)
            mask = jax.lax.lt((jax.lax.broadcast_in_dim(
                x, shape=(N, M), broadcast_dimensions=(0, )) + k),
                              jax.lax.broadcast(y, [N]))
            mask = jax.lax.convert_element_type(mask, np.float32)
            return mask

        def forward_slice(query_slice, q_loop_idx, key, value):  # pylint: disable=invalid-name
            """Forward pass for a subset of the query vectors."""
            dots = np.matmul(query_slice, np.swapaxes(key, -1,
                                                      -2)) / np.sqrt(depth)

            # Causal masking
            mask = make_mask(dots.shape[-2], dots.shape[-1], q_loop_idx)
            dots = dots - 1e9 * mask

            # Softmax.
            dots = np.exp(dots -
                          backend.logsumexp(dots, axis=-1, keepdims=True))

            if self.dropout is not None and self.dropout > 0.0:
                # Dropout is broadcast across the batch+head dimension
                dropout_shape = (1, dots.shape[-2], dots.shape[-1])
                slice_rng = jax.random.fold_in(rng, q_loop_idx)
                keep_prob = jax.lax.tie_in(dots, 1.0 - self.dropout)
                keep = backend.random.bernoulli(slice_rng, keep_prob,
                                                dropout_shape)
                multiplier = keep.astype(dots.dtype) / jax.lax.tie_in(
                    keep, keep_prob)
                dots = dots * multiplier

            out_slice = np.matmul(dots, value)
            return out_slice

        def forward_and_vjp_slice(query_slice, q_loop_idx, key, value,
                                  ct_slice):  # pylint: disable=invalid-name
            # Capture q_loop_idx to avoid calculating gradients w.r.t. it.
            def forward_slice_with_q_loop_idx(query_slice, key, value):  # pylint: disable=invalid-name
                return forward_slice(query_slice, q_loop_idx, key, value)

            output_slice, vjpfun = jax.vjp(forward_slice_with_q_loop_idx,
                                           query_slice, key, value)
            return output_slice, vjpfun(ct_slice)

        q_loop_idx = np.zeros((), dtype=np.int32)
        q_loop_max = query.shape[-2]
        q_loop_stride = self._loop_stride
        assert q_loop_max % q_loop_stride == 0, (
            'Stride must evenly divide the number of query elements.')

        out_accum = np.zeros_like(query)
        if do_backprop:
            query_ct_accum = np.zeros_like(query)
            key_ct_accum = np.zeros_like(key)
            value_ct_accum = np.zeros_like(value)
            init_vals = (q_loop_idx, out_accum, query_ct_accum, key_ct_accum,
                         value_ct_accum)
        else:
            init_vals = (q_loop_idx, out_accum)

        def cond_fun(vals):  # pylint: disable=invalid-name
            q_loop_idx = vals[0]
            return jax.lax.lt(q_loop_idx, q_loop_max)

        def body_fun(vals):  # pylint: disable=invalid-name
            """Compute a slice of the attention mechanism."""
            if do_backprop:
                (q_loop_idx, out_accum, query_ct_accum, key_ct_accum,
                 value_ct_accum) = vals
            else:
                q_loop_idx, out_accum = vals

            query_slice = jax.lax.dynamic_slice_in_dim(query,
                                                       q_loop_idx,
                                                       q_loop_stride,
                                                       axis=-2)

            if do_backprop:
                ct_slice = jax.lax.dynamic_slice_in_dim(ct,
                                                        q_loop_idx,
                                                        q_loop_stride,
                                                        axis=-2)
                out_slice, partial_ct = forward_and_vjp_slice(
                    query_slice, q_loop_idx, key, value, ct_slice)
                query_ct_accum = jax.lax.dynamic_update_slice_in_dim(
                    query_ct_accum, partial_ct[0], q_loop_idx, axis=-2)
                key_ct_accum = key_ct_accum + partial_ct[1]
                value_ct_accum = value_ct_accum + partial_ct[2]
            else:
                out_slice = forward_slice(query_slice, q_loop_idx, key, value)

            out_accum = jax.lax.dynamic_update_slice_in_dim(out_accum,
                                                            out_slice,
                                                            q_loop_idx,
                                                            axis=-2)
            q_loop_idx = q_loop_idx + q_loop_stride

            if do_backprop:
                return (q_loop_idx, out_accum, query_ct_accum, key_ct_accum,
                        value_ct_accum)
            else:
                return (q_loop_idx, out_accum)

        final_vals = jax.lax.while_loop(cond_fun, body_fun, init_vals)

        if not do_backprop:
            return final_vals[1], None
        else:
            return final_vals[1], final_vals[2:]
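A brief aside (illustrative; scipy.special.logsumexp standing in for backend.logsumexp): subtracting the log-sum-exp before exponentiating, as done above, is a numerically stable softmax:
import numpy as np
from scipy.special import logsumexp

dots = np.array([[1.0, 2.0, 3.0]])
soft = np.exp(dots - logsumexp(dots, axis=-1, keepdims=True))
print(soft.sum(axis=-1))  # [1.]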
Example #17
 def init_fun(_, input_shape):
     features = input_shape[-1]
     scale = np.ones(features)
     bias = np.zeros(features)
     return input_shape, (scale, bias)
Example #18
  def batch_call_and_or_grad(self, qk, v, ct=None, return_output=True,
                             rng=None):
    assert return_output or ct is not None, 'No work to perform!'
    # pylint: disable=protected-access
    stash_buckets = (return_output and ct is None
                     and base.Layer._STASH_IN is not None)
    if return_output and ct is not None and base.Layer._STASH_OUT is not None:
      buckets = base.Layer._STASH_OUT.pop(self)
    else:
      buckets = None
    # pylint: enable=protected-access

    # The approach here is to perform attention for one batch element and head
    # at a time. Note that there is absolutely no interaction across examples or
    # heads: this layer has no parameters, and hashing patterns are also
    # different across examples/heads. As a result, batching doesn't give any
    # performance gains except in the case of accelerator under-utilization. We
    # assume that hash-based attention will be applied primarily to long
    # sequences, where unbatched attention for a single head has sufficient
    # computation to fill up the accelerator.

    batch_loop_idx = np.zeros((), dtype=np.int32)
    batch_loop_max = qk.shape[0]

    init_vals = (batch_loop_idx,)
    if return_output:
      out_accum = np.zeros_like(qk)
      init_vals = init_vals + (out_accum,)
    if stash_buckets:
      buckets_accum = np.zeros(
          [qk.shape[0], self.n_hashes * qk.shape[1]], dtype=np.int32)
      init_vals = init_vals + (buckets_accum,)
    if ct is not None:
      qk_ct_accum = np.zeros_like(qk)
      v_ct_accum = np.zeros_like(v)
      init_vals = init_vals + (qk_ct_accum, v_ct_accum)

    def cond_fun(vals):
      batch_loop_idx = vals[0]
      return jax.lax.lt(batch_loop_idx, batch_loop_max)

    def body_fun(vals):
      """Performs attention for a single batch element and head."""
      batch_loop_idx = vals[0]
      if self._prng is None:
        hash_rng = jax.random.fold_in(rng, batch_loop_idx)
      else:
        # TODO(kitaev): Maybe use the same RNG across examples (but not heads)?
        hash_rng = jax.random.fold_in(self._prng, batch_loop_idx)
      qk_slice = jax.lax.dynamic_index_in_dim(
          qk, batch_loop_idx, axis=0, keepdims=False)
      v_slice = jax.lax.dynamic_index_in_dim(
          v, batch_loop_idx, axis=0, keepdims=False)

      if buckets is None:
        buckets_slice = self.hash_vectors(qk_slice, rng=hash_rng)
      else:
        buckets_slice = jax.lax.dynamic_index_in_dim(
            buckets, batch_loop_idx, axis=0, keepdims=False)

      if ct is None:
        out_slice = self.single_call(
            qk_slice, v_slice, buckets_slice, hash_rng=hash_rng)
      else:
        def _do_single_call(qk_slice, v_slice):
          return self.single_call(
              qk_slice, v_slice, buckets_slice, hash_rng=hash_rng)
        ct_slice = jax.lax.dynamic_index_in_dim(
            ct, batch_loop_idx, axis=0, keepdims=False)
        out_slice, vjpfun = jax.vjp(_do_single_call, qk_slice, v_slice)
        qk_ct_slice, v_ct_slice = vjpfun(ct_slice)

      new_vals = (batch_loop_idx + 1,)
      if return_output:
        out_accum = vals[1]
        out_accum = jax.lax.dynamic_update_index_in_dim(
            out_accum, out_slice, batch_loop_idx, axis=0)
        new_vals = new_vals + (out_accum,)
      if stash_buckets:
        buckets_accum = vals[2]
        buckets_accum = jax.lax.dynamic_update_index_in_dim(
            buckets_accum, buckets_slice, batch_loop_idx, axis=0)
        new_vals = new_vals + (buckets_accum,)
      if ct is not None:
        qk_ct_accum, v_ct_accum = vals[-2:]
        qk_ct_accum = jax.lax.dynamic_update_index_in_dim(
            qk_ct_accum, qk_ct_slice, batch_loop_idx, axis=0)
        v_ct_accum = jax.lax.dynamic_update_index_in_dim(
            v_ct_accum, v_ct_slice, batch_loop_idx, axis=0)
        new_vals = new_vals + (qk_ct_accum, v_ct_accum)

      return new_vals

    final_vals = jax.lax.while_loop(cond_fun, body_fun, init_vals)

    if return_output:
      out = final_vals[1]
    else:
      out = None

    if stash_buckets:
      base.Layer._STASH_IN[self] = final_vals[2]  # pylint: disable=protected-access

    if ct is not None:
      input_ct = final_vals[-2:]
    else:
      input_ct = None

    return out, input_ct
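As an isolated illustration (not from the source) of the accumulation pattern in the loop body above, jax.lax.dynamic_update_index_in_dim writes one batch element's result back into the accumulator:
import jax
import jax.numpy as jnp

out_accum = jnp.zeros((3, 4))   # per-batch accumulator
out_slice = jnp.ones((4,))      # result for one batch element
out_accum = jax.lax.dynamic_update_index_in_dim(out_accum, out_slice, 1, axis=0)
print(out_accum[1])             # [1. 1. 1. 1.]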
Example #19
    return init


def glorot(out_dim=0, in_dim=1, scale=onp.sqrt(2)):
    """An initializer function for random Glorot-scaled coefficients."""
    def init(rng, shape):
        fan_in, fan_out = shape[in_dim], shape[out_dim]
        size = onp.prod(onp.delete(shape, [in_dim, out_dim]))
        std = scale / np.sqrt((fan_in + fan_out) / 2. * size)
        return (std * backend.random.normal(rng, shape)).astype('float32')

    return init


zeros = lambda rng, shape: np.zeros(shape, dtype='float32')
ones = lambda rng, shape: np.ones(shape, dtype='float32')

# Layers

# Each layer constructor function returns an (init_fun, apply_fun) pair, where
#   init_fun: takes an input shape and returns an (output_shape, params) pair,
#   apply_fun: takes params, inputs, and an rng key and applies the layer.
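# For instance (illustrative sketch, not from the source), a constructed layer
# would be used roughly as:
#   init_fun, apply_fun = Dense(512)
#   output_shape, params = init_fun(rng, (batch_size, 256))
#   y = apply_fun(params, x, rng=rng)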


def Dense(out_dim, W_init=glorot(), b_init=randn()):
    """Layer constructor function for a dense (fully-connected) layer."""
    def init_fun(rng, input_shape):
        output_shape = input_shape[:-1] + (out_dim, )
        w, b = W_init(rng,
                      (input_shape[-1], out_dim)), b_init(rng, (out_dim, ))
Example #20
 def dummy_loss_fn(params):
     inputs = (np.zeros(input_shape[0], dtype=np.int32), ) * 2
     output = model(inputs, params=params, state=state, rng=rng)
     dummy_loss = backend.numpy.sum(output[0])
     return dummy_loss