Example #1
def NewPositionalEncoding(x, positions=None, **kwargs):
    """Implements new positional encoding."""
    del kwargs
    x_length = np.shape(x)[1]
    pos = np.array(positions)[np.newaxis, :x_length, :]
    pos += np.zeros((np.shape(x)[0], 1, 1))  # Broadcast on batch.
    res = np.concatenate([x, pos], axis=2)
    return res
def ChunkedPositionalEncoding(x, params, **unused_kwargs):
    """Implements bare positional encoding."""
    if not isinstance(x, (list, tuple)):  # non-chunked inputs
        symbol_size = np.shape(x)[1]
        return x + params[:, :symbol_size, :]
    # Chunked case: apply the encoding to each chunk, consuming as many
    # positions of the table as each chunk needs.
    offset = 0
    results = []
    for chunk in x:
        symbol_size = np.shape(chunk)[1]
        results.append(chunk + params[:, offset:offset + symbol_size, :])
        offset += symbol_size
    return results
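A minimal usage sketch for the two layers above. It assumes `np` is bound to `jax.numpy` (as in the library these snippets come from); the position table, `params`, and chunk sizes are illustrative placeholders, not values the library would actually produce.

import jax
import jax.numpy as np  # assumption: the backend `np` used above is jax.numpy

rng = jax.random.PRNGKey(0)
batch, length, d_model, d_pos, max_len = 2, 5, 8, 3, 16
x = jax.random.normal(rng, (batch, length, d_model))

# NewPositionalEncoding concatenates explicit position features onto x.
positions = jax.random.normal(rng, (max_len, d_pos))  # hypothetical position table
out = NewPositionalEncoding(x, positions=positions)
assert out.shape == (batch, length, d_model + d_pos)

# ChunkedPositionalEncoding adds slices of a (1, max_len, d_model) table,
# advancing the offset from chunk to chunk.
params = jax.random.normal(rng, (1, max_len, d_model))  # hypothetical encoding table
outs = ChunkedPositionalEncoding([x[:, :2, :], x[:, 2:, :]], params)
assert [o.shape for o in outs] == [(batch, 2, d_model), (batch, 3, d_model)]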
Example #3
def DotProductAttention(query, key, value, mask, dropout, mode, rng):
    """Core dot product self-attention.

    Args:
      query: array of representations
      key: array of representations
      value: array of representations
      mask: attention-mask, gates attention
      dropout: float: dropout rate
      mode: 'eval' or 'train': whether to use dropout
      rng: JAX PRNGKey: subkey for disposable use

    Returns:
      Self attention for q, k, v arrays.
    """
    depth = np.shape(query)[-1]
    dots = np.matmul(query, np.swapaxes(key, -1, -2)) / np.sqrt(depth)
    if mask is not None:
        # TODO(kitaev): workaround for https://github.com/google/jax/issues/850
        # We must ensure that both mask and the -1e9 constant have a data dependency
        # on the input. Broadcasted copies of these use a lot of memory, so they
        # should be computed at runtime (rather than being global constants).
        if backend.get_name() == 'jax':
            mask = jax.lax.tie_in(dots, mask)
        dots = np.where(mask, dots, np.full_like(dots, -1e9))
    # Softmax.
    dots = np.exp(dots - backend.logsumexp(dots, axis=-1, keepdims=True))
    if dropout is not None and dropout >= 1.0:
        raise ValueError('Dropout rates must be lower than 1.')
    if dropout is not None and dropout > 0.0 and mode == 'train':
        keep = backend.random.bernoulli(rng, 1.0 - dropout, dots.shape)
        dots = np.where(keep, dots / (1.0 - dropout), np.zeros_like(dots))
    out = np.matmul(dots, value)
    return out
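For readers without the `backend` module at hand, here is a self-contained sketch of the same masked softmax attention in plain `jax.numpy`; the helper name and sizes are illustrative, and `jax.nn.softmax` stands in for the logsumexp-based softmax above.

import jax
import jax.numpy as jnp

def reference_attention(q, k, v, mask):
  depth = q.shape[-1]
  dots = jnp.matmul(q, jnp.swapaxes(k, -1, -2)) / jnp.sqrt(depth)
  dots = jnp.where(mask, dots, -1e9)          # gate disallowed positions
  return jnp.matmul(jax.nn.softmax(dots, axis=-1), v)

rng = jax.random.PRNGKey(0)
q = jax.random.normal(rng, (2, 4, 6, 8))      # (batch, heads, seqlen, d_head)
causal = jnp.tril(jnp.ones((1, 1, 6, 6), dtype=bool))
out = reference_attention(q, q, q, causal)
assert out.shape == (2, 4, 6, 8)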
Example #4
def DotProductAttention(query, key, value, mask, dropout, mode, rng):
  """Core dot product self-attention.

  Args:
    query: array of representations
    key: array of representations
    value: array of representations
    mask: attention-mask, gates attention
    dropout: float: dropout rate
    mode: 'eval' or 'train': whether to use dropout
    rng: JAX PRNGKey: subkey for disposable use

  Returns:
    Self attention for q, k, v arrays.
  """
  depth = np.shape(query)[-1]
  dots = np.matmul(query, np.swapaxes(key, -1, -2)) / np.sqrt(depth)
  if mask is not None:
    dots = np.where(mask, dots, -1e9)
  # Softmax.
  dots = np.exp(dots - backend.logsumexp(dots, axis=-1, keepdims=True))
  if dropout is not None and dropout >= 1.0:
    raise ValueError('Dropout rates must be lower than 1.')
  if dropout is not None and dropout > 0.0 and mode == 'train':
    keep = backend.random.bernoulli(rng, 1.0 - dropout, dots.shape)
    dots = np.where(keep, dots / (1.0 - dropout), 0)
  out = np.matmul(dots, value)
  return out
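A quick numeric check of the inverted-dropout rescaling used above: with keep probability 1 - dropout, the surviving weights are scaled by 1/(1 - dropout), so the expected value of the attention weights is unchanged. Sizes and the 0.1 rate are arbitrary.

import jax
import jax.numpy as jnp

rng = jax.random.PRNGKey(0)
dropout = 0.1
dots = jnp.ones((10000,))
keep = jax.random.bernoulli(rng, 1.0 - dropout, dots.shape)
dropped = jnp.where(keep, dots / (1.0 - dropout), 0.0)
assert abs(float(dropped.mean()) - 1.0) < 0.05  # mean is preserved (up to noise)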
Example #5
def PureAttention(x, params, n_heads=1, dropout=0.0, mode='train', **kwargs):
  """Pure transformer-style multi-headed attention.

  Args:
    x: inputs (q, k, v, mask)
    params: parameters (none)
    n_heads: int: number of attention heads
    dropout: float: dropout rate
    mode: str: 'train' or 'eval'
    **kwargs: other arguments including the rng

  Returns:
    Pure Multi-headed attention result, and the mask.
  """
  del params
  rng = kwargs.get('rng', None)
  q, k, v, mask = x
  d_feature = q.shape[-1]
  assert d_feature % n_heads == 0
  d_head = d_feature // n_heads
  nbatch = np.shape(q)[0]
  # nbatch, seqlen, d_feature --> nbatch, n_heads, seqlen, d_head
  def SplitHeads(x):
    return np.transpose(
        np.reshape(x, (nbatch, -1, n_heads, d_head)), (0, 2, 1, 3))
  # nbatch, n_heads, seqlen, d_head --> nbatch, seqlen, d_feature
  def JoinHeads(x):  # pylint: disable=invalid-name
    return np.reshape(
        np.transpose(x, (0, 2, 1, 3)), (nbatch, -1, n_heads * d_head))
  # Split heads, dot-product attention, rejoin heads.
  res = JoinHeads(
      DotProductAttention(
          SplitHeads(q), SplitHeads(k), SplitHeads(v), mask,
          dropout=dropout, mode=mode, rng=rng))
  return res, mask  # Keep the mask.
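A shape-only sketch of the SplitHeads/JoinHeads reshapes above, in plain `jax.numpy` with illustrative sizes; it shows that splitting moves the head axis in front of the sequence axis and that joining restores the original layout.

import jax.numpy as jnp

nbatch, seqlen, n_heads, d_head = 2, 5, 4, 8
x = jnp.zeros((nbatch, seqlen, n_heads * d_head))
split = jnp.transpose(jnp.reshape(x, (nbatch, -1, n_heads, d_head)), (0, 2, 1, 3))
assert split.shape == (nbatch, n_heads, seqlen, d_head)
joined = jnp.reshape(jnp.transpose(split, (0, 2, 1, 3)), (nbatch, -1, n_heads * d_head))
assert joined.shape == (nbatch, seqlen, n_heads * d_head)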
Example #6
    def apply_fun(params, inputs, **kwargs):  # pylint: disable=missing-docstring
        del params
        rng = kwargs.get('rng', None)
        q, k, v, mask = inputs
        assert feature_depth % num_heads == 0
        head_depth = feature_depth // num_heads
        nbatch = np.shape(q)[0]

        # nbatch, seqlen, feature_depth --> nbatch, num_heads, seqlen, head_depth
        def split_heads(x):
            return np.transpose(
                np.reshape(x, (nbatch, -1, num_heads, head_depth)),
                (0, 2, 1, 3))

        # nbatch, num_heads, seqlen, head_depth --> nbatch, seqlen, feature_depth
        def join_heads(x):
            return np.reshape(np.transpose(x, (0, 2, 1, 3)),
                              (nbatch, -1, num_heads * head_depth))

        # Split heads, dot-product attention, rejoin heads.
        return join_heads(
            dot_product_attention(split_heads(q),
                                  split_heads(k),
                                  split_heads(v),
                                  mask,
                                  dropout=dropout,
                                  mode=mode,
                                  rng=rng))
Example #7
def dot_product_attention(query, key, value, mask, dropout, mode, rng):
    """Core dot product self-attention.

    Args:
      query: array of representations
      key: array of representations
      value: array of representations
      mask: attention-mask, gates attention
      dropout: float: keep probability (this version treats the argument as the
        probability of keeping an activation, not as a drop rate)
      mode: 'eval' or 'train': whether to use dropout
      rng: JAX PRNGKey: subkey for disposable use

    Returns:
      Self attention for q, k, v arrays.
    """
    depth = np.shape(query)[-1]
    dots = np.matmul(query, np.swapaxes(key, -1, -2)) / np.sqrt(depth)
    if mask is not None:
        dots = np.where(mask, dots, -1e9)
    dots = stax.softmax(dots, axis=-1)
    if dropout is not None and mode == 'train':
        keep = random.bernoulli(rng, dropout, dots.shape)
        dots = np.where(keep, dots / dropout, 0)
    out = np.matmul(dots, value)
    return out
def SplitHeads(x, params, n_heads=1, **kwargs):
    del params, kwargs
    d_model = x.shape[-1]
    assert d_model % n_heads == 0
    d_head = d_model // n_heads
    n_batch = np.shape(x)[0]
    # n_batch, seqlen, d_model --> n_batch, n_heads, seqlen, d_head
    return np.transpose(np.reshape(x, (n_batch, -1, n_heads, d_head)),
                        (0, 2, 1, 3))
  def forward(self, inputs, params=(), state=(), **kwargs):
    if self._mode in ('train', 'eval'):
      x = inputs
      symbol_size = np.shape(x)[1]
      return (x + params[:, :symbol_size, :], state)
    else:
      assert self._mode == 'predict'
      # Fast inference: return consecutive elements of the encoding sequence,
      # storing the index in state.
      return (inputs + np.expand_dims(params[:, state, :], 1), state + 1)
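A minimal sketch of the predict-mode bookkeeping above, assuming `np` is `jax.numpy`, `params` is a (1, max_len, d_model) encoding table, and `state` starts at position 0. Each single-token call adds the encoding at the current position and advances the index; the numbers are placeholders.

import jax.numpy as np  # assumption: jax.numpy stands in for the backend `np`

d_model, max_len = 4, 8
params = np.arange(max_len * d_model, dtype=np.float32).reshape(1, max_len, d_model)
token, state = np.zeros((1, 1, d_model)), 0
out, state = token + np.expand_dims(params[:, state, :], 1), state + 1
assert out.shape == (1, 1, d_model) and state == 1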
Example #10
def PureMultiHeadedAttention(params,
                             x,
                             feature_depth=None,
                             num_heads=8,
                             dropout=0.0,
                             mode='train',
                             **kwargs):
    """Pure transformer-style multi-headed attention.

    Args:
      params: parameters (none)
      x: inputs (q, k, v, mask)
      feature_depth: int: depth of embedding
      num_heads: int: number of attention heads
      dropout: float: dropout rate
      mode: str: 'train' or 'eval'
      **kwargs: other arguments including the rng

    Returns:
      Pure Multi-headed attention layer. (No Dense transforms on input.)
    """
    del params
    rng = kwargs.get('rng', None)
    q, k, v, mask = x
    assert feature_depth % num_heads == 0
    head_depth = feature_depth // num_heads
    nbatch = np.shape(q)[0]

    # nbatch, seqlen, feature_depth --> nbatch, num_heads, seqlen, head_depth
    def SplitHeads(x):
        return np.transpose(np.reshape(x, (nbatch, -1, num_heads, head_depth)),
                            (0, 2, 1, 3))

    # nbatch, num_heads, seqlen, head_depth --> nbatch, seqlen, feature_depth
    def JoinHeads(x):  # pylint: disable=invalid-name
        return np.reshape(np.transpose(x, (0, 2, 1, 3)),
                          (nbatch, -1, num_heads * head_depth))

    # Split heads, dot-product attention, rejoin heads.
    return JoinHeads(
        DotProductAttention(SplitHeads(q),
                            SplitHeads(k),
                            SplitHeads(v),
                            mask,
                            dropout=dropout,
                            mode=mode,
                            rng=rng))
Example #11
def PositionalEncoding(x, params, **unused_kwargs):
    """Implements bare positional encoding."""
    symbol_size = np.shape(x)[1]
    return x + params[:, :symbol_size, :]
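Example #11's layer just adds a precomputed (1, max_len, d_feature) table to the input. One common way to build such a table is the sinusoidal scheme from the Transformer paper; this sketch is only illustrative and may differ from how the library actually initializes `params` (plain NumPy is imported as `onp` to avoid clashing with the backend `np`).

import numpy as onp

def sinusoidal_table(max_len, d_feature):
  pos = onp.arange(max_len)[:, None]
  dim = onp.arange(d_feature)[None, :]
  angle = pos / onp.power(10000.0, 2 * (dim // 2) / d_feature)
  # Even feature indices get sine, odd ones cosine.
  return onp.where(dim % 2 == 0, onp.sin(angle), onp.cos(angle))[None, :, :]

params = sinusoidal_table(max_len=16, d_feature=8)
assert params.shape == (1, 16, 8)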
Example #12
    def call_and_grad(self, inputs, ct, rng=None, **kwargs):
        del kwargs
        query, key, value = inputs
        depth = np.shape(query)[-1]
        do_backprop = ct is not None

        # jax uses the term cotangent (ct) to refer to gradient signals, and
        # vector-Jacobian product (vjp) for back-propagation through a layer.

        def make_mask(N, M, k):  # pylint: disable=invalid-name
            """Constructs a slice of the causal attention mask.

            Args:
              N: number of query positions
              M: number of key positions
              k: position of the initial query element

            Returns:
              N x M mask, where 1.0 indicates that attention is not allowed.
            """
            x = np.arange(N, dtype=np.int32)
            y = np.arange(M, dtype=np.int32)
            mask = jax.lax.lt((jax.lax.broadcast_in_dim(
                x, shape=(N, M), broadcast_dimensions=(0, )) + k),
                              jax.lax.broadcast(y, [N]))
            mask = jax.lax.convert_element_type(mask, np.float32)
            return mask

        def forward_slice(query_slice, q_loop_idx, key, value):  # pylint: disable=invalid-name
            """Forward pass for a subset of the query vectors."""
            dots = np.matmul(query_slice, np.swapaxes(key, -1,
                                                      -2)) / np.sqrt(depth)

            # Causal masking
            mask = make_mask(dots.shape[-2], dots.shape[-1], q_loop_idx)
            dots = dots - 1e9 * mask

            # Softmax.
            dots = np.exp(dots -
                          backend.logsumexp(dots, axis=-1, keepdims=True))

            if self.dropout is not None and self.dropout > 0.0:
                # Dropout is broadcast across the batch+head dimension
                dropout_shape = (1, dots.shape[-2], dots.shape[-1])
                slice_rng = jax.random.fold_in(rng, q_loop_idx)
                keep_prob = jax.lax.tie_in(dots, 1.0 - self.dropout)
                keep = backend.random.bernoulli(slice_rng, keep_prob,
                                                dropout_shape)
                multiplier = keep.astype(dots.dtype) / jax.lax.tie_in(
                    keep, keep_prob)
                dots = dots * multiplier

            out_slice = np.matmul(dots, value)
            return out_slice

        def forward_and_vjp_slice(query_slice, q_loop_idx, key, value,
                                  ct_slice):  # pylint: disable=invalid-name
            # Capture q_loop_idx so that gradients are not computed with respect to it.
            def forward_slice_with_q_loop_idx(query_slice, key, value):  # pylint: disable=invalid-name
                return forward_slice(query_slice, q_loop_idx, key, value)

            output_slice, vjpfun = jax.vjp(forward_slice_with_q_loop_idx,
                                           query_slice, key, value)
            return output_slice, vjpfun(ct_slice)

        q_loop_idx = np.zeros((), dtype=np.int32)
        q_loop_max = query.shape[-2]
        q_loop_stride = self._loop_stride
        assert q_loop_max % q_loop_stride == 0, (
            'Stride must evenly divide the number of query elements.')

        out_accum = np.zeros_like(query)
        if do_backprop:
            query_ct_accum = np.zeros_like(query)
            key_ct_accum = np.zeros_like(key)
            value_ct_accum = np.zeros_like(value)
            init_vals = (q_loop_idx, out_accum, query_ct_accum, key_ct_accum,
                         value_ct_accum)
        else:
            init_vals = (q_loop_idx, out_accum)

        def cond_fun(vals):  # pylint: disable=invalid-name
            q_loop_idx = vals[0]
            return jax.lax.lt(q_loop_idx, q_loop_max)

        def body_fun(vals):  # pylint: disable=invalid-name
            """Compute a slice of the attention mechanism."""
            if do_backprop:
                (q_loop_idx, out_accum, query_ct_accum, key_ct_accum,
                 value_ct_accum) = vals
            else:
                q_loop_idx, out_accum = vals

            query_slice = jax.lax.dynamic_slice_in_dim(query,
                                                       q_loop_idx,
                                                       q_loop_stride,
                                                       axis=-2)

            if do_backprop:
                ct_slice = jax.lax.dynamic_slice_in_dim(ct,
                                                        q_loop_idx,
                                                        q_loop_stride,
                                                        axis=-2)
                out_slice, partial_ct = forward_and_vjp_slice(
                    query_slice, q_loop_idx, key, value, ct_slice)
                query_ct_accum = jax.lax.dynamic_update_slice_in_dim(
                    query_ct_accum, partial_ct[0], q_loop_idx, axis=-2)
                key_ct_accum = key_ct_accum + partial_ct[1]
                value_ct_accum = value_ct_accum + partial_ct[2]
            else:
                out_slice = forward_slice(query_slice, q_loop_idx, key, value)

            out_accum = jax.lax.dynamic_update_slice_in_dim(out_accum,
                                                            out_slice,
                                                            q_loop_idx,
                                                            axis=-2)
            q_loop_idx = q_loop_idx + q_loop_stride

            if do_backprop:
                return (q_loop_idx, out_accum, query_ct_accum, key_ct_accum,
                        value_ct_accum)
            else:
                return (q_loop_idx, out_accum)

        final_vals = jax.lax.while_loop(cond_fun, body_fun, init_vals)

        if not do_backprop:
            return final_vals[1], None
        else:
            return final_vals[1], final_vals[2:]
    def forward_and_vjp(self, inputs, ct, params=(), **kwargs):
        # This is the core of the memory-efficient attention implementation, where
        # we use the jax.lax.while_loop primitive to compute attention for a small
        # set of query positions at a time. Note how in the backwards pass, we
        # compute both the forward direction (to recover the previous layer's
        # activations) and the backward direction simultaneously. This allows us to
        # only use a single loop, where the inner portion of the loop does a slice
        # of the forward+backward joint computation. Unfortunately we have had to
        # introduce a large number of wrapper classes (including
        # ReversibleAttentionHalfResidual and ApplyAttentionWrapper) for the sole
        # purpose of connecting this implementation of forward_and_vjp with the core
        # backprop implementation. (A standalone sketch of the chunked causal
        # attention idea appears after this method.)

        query, key, value = inputs
        depth = np.shape(query)[-1]
        do_backprop = ct is not None

        def make_mask(N, M, k):
            x = np.arange(N, dtype=np.int32)
            y = np.arange(M, dtype=np.int32)
            mask = jax.lax.lt((jax.lax.broadcast_in_dim(
                x, shape=(N, M), broadcast_dimensions=(0, )) + k),
                              jax.lax.broadcast(y, [N]))
            mask = jax.lax.convert_element_type(mask, np.float32)
            return mask

        def forward_slice(query_slice, q_loop_idx, key, value):
            """Forward pass for a subset of the query vectors."""
            dots = np.matmul(query_slice, np.swapaxes(key, -1,
                                                      -2)) / np.sqrt(depth)

            # Causal masking
            mask = make_mask(dots.shape[-2], dots.shape[-1], q_loop_idx)
            dots = dots - 1e9 * mask

            # Softmax.
            dots = np.exp(dots - dots.max(axis=-1, keepdims=True))
            dots = dots / dots.sum(axis=-1, keepdims=True)
            out_slice = np.matmul(dots, value)
            return out_slice

        def forward_and_vjp_slice(query_slice, q_loop_idx, key, value,
                                  ct_slice):
            output_slice, vjpfun = jax.vjp(forward_slice, query_slice,
                                           q_loop_idx, key, value)
            return output_slice, vjpfun(ct_slice)

        q_loop_idx = np.zeros((), dtype=np.int32)
        q_loop_max = query.shape[2]
        q_loop_stride = self._loop_stride
        assert q_loop_max % q_loop_stride == 0, (
            'Stride must evenly divide the number of query elements.')

        out_accum = np.zeros_like(query)
        if do_backprop:
            query_ct_accum = np.zeros_like(query)
            key_ct_accum = np.zeros_like(key)
            value_ct_accum = np.zeros_like(value)
            init_vals = (q_loop_idx, out_accum, query_ct_accum, key_ct_accum,
                         value_ct_accum)
        else:
            init_vals = (q_loop_idx, out_accum)

        def cond_fun(vals):
            q_loop_idx = vals[0]
            return jax.lax.lt(q_loop_idx, q_loop_max)

        def body_fun(vals):
            """Compute a slice of the attention mechanism."""
            if do_backprop:
                (q_loop_idx, out_accum, query_ct_accum, key_ct_accum,
                 value_ct_accum) = vals
            else:
                q_loop_idx, out_accum = vals

            query_slice = jax.lax.dynamic_slice_in_dim(query,
                                                       q_loop_idx,
                                                       q_loop_stride,
                                                       axis=2)

            if do_backprop:
                ct_slice = jax.lax.dynamic_slice_in_dim(ct,
                                                        q_loop_idx,
                                                        q_loop_stride,
                                                        axis=2)
                out_slice, partial_ct = forward_and_vjp_slice(
                    query_slice, q_loop_idx, key, value, ct_slice)
                query_ct_accum = jax.lax.dynamic_update_slice_in_dim(
                    query_ct_accum, partial_ct[0], q_loop_idx, axis=2)
                # Ignore partial_ct[1], which is the cotangent w.r.t. the loop index.
                key_ct_accum = key_ct_accum + partial_ct[2]
                value_ct_accum = value_ct_accum + partial_ct[3]
            else:
                out_slice = forward_slice(query_slice, q_loop_idx, key, value)

            out_accum = jax.lax.dynamic_update_slice_in_dim(out_accum,
                                                            out_slice,
                                                            q_loop_idx,
                                                            axis=2)
            q_loop_idx = q_loop_idx + q_loop_stride

            if do_backprop:
                return (q_loop_idx, out_accum, query_ct_accum, key_ct_accum,
                        value_ct_accum)
            else:
                return (q_loop_idx, out_accum)

        final_vals = jax.lax.while_loop(cond_fun, body_fun, init_vals)

        if not do_backprop:
            return final_vals[1], None
        else:
            return final_vals[1], final_vals[2:]
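The long comment above describes attention computed over small query slices inside a while_loop. The sketch below shows the underlying identity in isolation: causal attention over query slices, with the slice offset folded into the mask, matches full causal attention. Plain `jax.numpy`, illustrative sizes, no claim about the exact library API.

import jax
import jax.numpy as jnp

def causal_attention(q, k, v, offset=0):
  depth = q.shape[-1]
  dots = jnp.matmul(q, jnp.swapaxes(k, -1, -2)) / jnp.sqrt(depth)
  qpos = jnp.arange(q.shape[-2])[:, None] + offset
  kpos = jnp.arange(k.shape[-2])[None, :]
  dots = jnp.where(qpos >= kpos, dots, -1e9)  # mask out future key positions
  return jnp.matmul(jax.nn.softmax(dots, axis=-1), v)

rng = jax.random.PRNGKey(0)
q = jax.random.normal(rng, (8, 16))           # (seqlen, depth), single head
full = causal_attention(q, q, q)
chunked = jnp.concatenate(
    [causal_attention(q[i:i + 4], q, q, offset=i) for i in range(0, 8, 4)], axis=0)
assert jnp.allclose(full, chunked, atol=1e-5)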
def JoinHeads(x, params, **kwargs):
    del params, kwargs
    n_batch = np.shape(x)[0]
    seqlen = np.shape(x)[2]
    # n_batch, n_heads, seqlen, d_head --> n_batch, seqlen, d_model
    return np.reshape(np.transpose(x, (0, 2, 1, 3)), (n_batch, seqlen, -1))
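A quick round-trip check for the SplitHeads/JoinHeads layer functions above, assuming `np` is `jax.numpy`; joining the heads exactly undoes the split.

import jax.numpy as np  # assumption: jax.numpy stands in for the backend `np`

x = np.arange(2 * 6 * 8, dtype=np.float32).reshape(2, 6, 8)  # (batch, seqlen, d_model)
heads = SplitHeads(x, params=(), n_heads=4)                  # (2, 4, 6, 2)
assert heads.shape == (2, 4, 6, 2)
roundtrip = JoinHeads(heads, params=())
assert roundtrip.shape == x.shape and bool(np.all(roundtrip == x))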
Example #15
  def apply_fun(params, inputs, **kwargs):
    del kwargs
    pe = params
    symbol_size = np.shape(inputs)[1]
    return inputs + pe[:, :symbol_size]