def NewPositionalEncoding(x, positions=None, **kwargs):
  """Implements new positional encoding."""
  del kwargs
  x_length = np.shape(x)[1]
  pos = np.array(positions)[np.newaxis, :x_length, :]
  pos += np.zeros((np.shape(x)[0], 1, 1))  # Broadcast on batch.
  return pos
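The function above slices a table of precomputed position vectors down to the input length and broadcasts it over the batch dimension. As a rough shape check, here is the same slicing-and-broadcast logic in plain NumPy (the toy sizes and the explicit pos = pos + ... form are assumptions for illustration; the original uses the trax backend's np):

import numpy as np

batch, seq_len, pos_dim = 2, 5, 3
x = np.zeros((batch, seq_len, 8), dtype=np.float32)      # stand-in activations
positions = np.arange(10 * pos_dim, dtype=np.float32).reshape(10, pos_dim)

x_length = np.shape(x)[1]
pos = np.array(positions)[np.newaxis, :x_length, :]      # (1, x_length, pos_dim)
pos = pos + np.zeros((np.shape(x)[0], 1, 1))             # broadcast on batch
assert pos.shape == (batch, seq_len, pos_dim)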
Example #2
 def Initializer(shape, rng):
     del rng
     logging.info('Loading pretrained embeddings from %s', path)
     with tf.io.gfile.GFile(path, 'rb') as f:
         parameters = np.load(f)
     assert np.shape(parameters) == shape, ('Expected shape %s, got %s' %
                                            (shape, np.shape(parameters)))
     return parameters
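The path variable, logging, and tf.io.gfile in the snippet above come from the enclosing scope. A self-contained stand-in could look like the following sketch (load_pretrained_embeddings is a hypothetical name; plain open replaces tf.io.gfile and the expected shape is passed in explicitly):

import numpy as np

def load_pretrained_embeddings(path, expected_shape):
    # Load an .npy array and check it against the shape the model expects.
    with open(path, 'rb') as f:
        parameters = np.load(f)
    assert parameters.shape == expected_shape, (
        'Expected shape %s, got %s' % (expected_shape, parameters.shape))
    return parameters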
Example #3
 def forward_with_state(self,
                        inputs,
                        weights=base.EMPTY_WEIGHTS,
                        state=base.EMPTY_STATE,
                        rng=None,
                        **kwargs):
     if self._mode in ('train', 'eval'):
         x = inputs
         symbol_size = np.shape(x)[1]
         px = weights[:, :symbol_size, :]
         if self._dropout == 0:
             return (x + px, state)
         else:
             noise_shape = list(px.shape)
             for dim in self._dropout_broadcast_dims:
                 noise_shape[dim] = 1
             keep_prob = 1.0 - self._dropout
             if backend.get_name() == 'jax':
                 keep_prob = jax.lax.tie_in(
                     x, np.full((), keep_prob, dtype=x.dtype))
             keep = backend.random.bernoulli(rng, keep_prob,
                                             tuple(noise_shape))
             multiplier = keep.astype(x.dtype) / keep_prob
             return (x + px * multiplier, state)
     else:
         assert self._mode == 'predict'
         assert self._dropout == 0
          # State in this class is only used for fast inference. In that case,
          # the model is called position-by-position on consecutive elements,
          # so this positional encoding layer needs to store the index of the
          # current position and increment it on each call -- that's how the
          # state is used and updated below.
         return (inputs + np.expand_dims(weights[:, state, :], 1),
                 state + 1)
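The dropout branch above drops whole slices of the positional signal (axes listed in _dropout_broadcast_dims share one mask value) and rescales what survives by 1/keep_prob. A small NumPy sketch of that mask construction (broadcast_dropout_multiplier is a hypothetical helper; NumPy's Generator stands in for the trax/JAX RNG):

import numpy as np

def broadcast_dropout_multiplier(shape, keep_prob, broadcast_dims, rng):
    # Build the noise shape: axes in broadcast_dims share a single mask value.
    noise_shape = list(shape)
    for dim in broadcast_dims:
        noise_shape[dim] = 1
    keep = rng.random(noise_shape) < keep_prob      # Bernoulli(keep_prob) mask
    return keep.astype(np.float32) / keep_prob      # inverted-dropout scaling

rng = np.random.default_rng(0)
mult = broadcast_dropout_multiplier((4, 10, 16), 0.9, broadcast_dims=(0,), rng=rng)
assert mult.shape == (1, 10, 16)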
Example #4
def DotProductAttention(query, key, value, mask, dropout, mode, rng):
    """Core dot product self-attention.

  Args:
    query: array of representations
    key: array of representations
    value: array of representations
    mask: attention-mask, gates attention
    dropout: float: dropout rate
    mode: 'eval' or 'train': whether to use dropout
    rng: JAX PRNGKey: subkey for disposable use

  Returns:
    Self attention for q, k, v arrays.
  """
    depth = np.shape(query)[-1]
    dots = np.matmul(query, np.swapaxes(key, -1, -2)) / np.sqrt(depth)
    if mask is not None:
        # TODO(kitaev): workaround for https://github.com/google/jax/issues/850
        # We must ensure that both mask and the -1e9 constant have a data dependency
        # on the input. Broadcasted copies of these use a lot of memory, so they
        # should be computed at runtime (rather than being global constants).
        if backend.get_name() == 'jax':
            mask = jax.lax.tie_in(dots, mask)
        # JAX's `full_like` already ties in -1e9 to dots.
        dots = np.where(mask, dots, np.full_like(dots, -1e9))
    # Softmax.
    dots = np.exp(dots - backend.logsumexp(dots, axis=-1, keepdims=True))
    if dropout >= 1.0:
        raise ValueError('Dropout rates must be lower than 1.')
    if dropout is not None and dropout > 0.0 and mode == 'train':
        keep = backend.random.bernoulli(rng, 1.0 - dropout, dots.shape)
        dots = np.where(keep, dots / (1.0 - dropout), np.zeros_like(dots))
    out = np.matmul(dots, value)
    return out
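Setting aside the mask and dropout branches, the core of DotProductAttention is an ordinary scaled dot-product softmax. A plain-NumPy restatement for checking shapes (a max-subtracted softmax replaces the backend.logsumexp form used above; the sizes are arbitrary):

import numpy as np

def scaled_dot_product_attention(q, k, v):
    depth = q.shape[-1]
    dots = np.matmul(q, np.swapaxes(k, -1, -2)) / np.sqrt(depth)
    dots = np.exp(dots - dots.max(axis=-1, keepdims=True))   # stable softmax
    dots = dots / dots.sum(axis=-1, keepdims=True)
    return np.matmul(dots, v)

q = k = v = np.random.rand(2, 7, 16).astype(np.float32)
out = scaled_dot_product_attention(q, k, v)
assert out.shape == (2, 7, 16)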
Example #5
    def forward(self, inputs, weights):
        gamma, beta, epsilon_l = weights

        epsilon = self._init_epsilon
        if epsilon_l is not base.EMPTY_WEIGHTS:
            epsilon += np.abs(epsilon_l[0])

        # Average over all axes except batch (B) and channels (C).
        axis = tuple(range(1, len(np.shape(inputs)) - 1))
        # (B, 1, 1, C)
        nu2 = np.mean(inputs**2, axis=axis, keepdims=True)
        # (B, W, H, C)
        xhat = inputs / np.sqrt(nu2 + epsilon)

        return gamma * xhat + beta
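In the layer above, nu2 is the mean squared activation over every axis except batch and channels, so each example is rescaled by a per-channel RMS before the learned affine transform. A NumPy sketch of that computation for NHWC inputs (rms_normalize and the default epsilon are assumptions for illustration):

import numpy as np

def rms_normalize(x, gamma, beta, epsilon=1e-6):
    axis = tuple(range(1, x.ndim - 1))                # spatial axes, e.g. (1, 2)
    nu2 = np.mean(x ** 2, axis=axis, keepdims=True)   # (B, 1, 1, C)
    xhat = x / np.sqrt(nu2 + epsilon)                 # (B, H, W, C)
    return gamma * xhat + beta

x = np.random.rand(2, 8, 8, 4).astype(np.float32)
y = rms_normalize(x, gamma=np.ones(4, np.float32), beta=np.zeros(4, np.float32))
assert y.shape == x.shape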
Example #6
    def forward_with_state(self, x, weights, state, rng):
        """Pure transformer-style multi-headed attention.

    Args:
      x: inputs (q, k, v, mask)
      weights: parameters (none)
      state: parameters (none)
      rng: random number generator

    Returns:
      Pure Multi-headed attention result, and the mask.
    """
        del weights
        n_heads, dropout, mode = self._n_heads, self._dropout, self._mode
        q, k, v, mask = x
        d_feature = q.shape[-1]
        assert d_feature % n_heads == 0
        d_head = d_feature // n_heads
        nbatch = np.shape(q)[0]

        # nbatch, seqlen, d_feature --> nbatch, n_heads, seqlen, d_head
        def SplitHeads(x):
            return np.transpose(np.reshape(x, (nbatch, -1, n_heads, d_head)),
                                (0, 2, 1, 3))

        # nbatch, n_heads, seqlen, d_head --> nbatch, seqlen, d_feature
        def JoinHeads(x):  # pylint: disable=invalid-name
            return np.reshape(np.transpose(x, (0, 2, 1, 3)),
                              (nbatch, -1, n_heads * d_head))

        # Split heads, dot-product attention, rejoin heads.
        res = JoinHeads(
            DotProductAttention(SplitHeads(q),
                                SplitHeads(k),
                                SplitHeads(v),
                                mask,
                                dropout=dropout,
                                mode=mode,
                                rng=rng))
        return (res, mask), state  # Keep the mask.
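SplitHeads and JoinHeads above only reshape and transpose, so joining after splitting recovers the original tensor exactly. A quick NumPy check of that round trip (toy sizes assumed):

import numpy as np

nbatch, seqlen, n_heads, d_head = 2, 5, 4, 8
x = np.random.rand(nbatch, seqlen, n_heads * d_head).astype(np.float32)

# (nbatch, seqlen, d_feature) -> (nbatch, n_heads, seqlen, d_head)
split = np.transpose(np.reshape(x, (nbatch, -1, n_heads, d_head)), (0, 2, 1, 3))
# ... and back to (nbatch, seqlen, d_feature).
joined = np.reshape(np.transpose(split, (0, 2, 1, 3)), (nbatch, -1, n_heads * d_head))

assert split.shape == (nbatch, n_heads, seqlen, d_head)
assert np.array_equal(joined, x)   # the round trip is lossless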
    def forward_and_backward(self,
                             inputs,
                             ct,
                             state=base.EMPTY_STATE,
                             new_state=base.EMPTY_STATE,
                             rng=None,
                             **kwargs):
        del state, new_state, kwargs
        query, key, value = inputs
        depth = np.shape(query)[-1]
        do_backprop = ct is not None

        # jax uses the term cotangent (ct) to refer to gradient signals, and
        # vector-Jacobian product (vjp) for back-propagation through a layer.

        def make_mask(N, M, k):  # pylint: disable=invalid-name
            """Constructs a slice of the causal attention mask.

      Args:
        N: number of query positions
        M: number of key positions
        k: position of the initial query element

      Returns:
        N x M mask, where 1.0 indicates that attention is not allowed.
      """
            x = jax.lax.tie_in(k, np.arange(N, dtype=np.int32))
            y = jax.lax.tie_in(k, np.arange(M, dtype=np.int32))
            mask = jax.lax.lt((jax.lax.broadcast_in_dim(
                x, shape=(N, M), broadcast_dimensions=(0, )) + k),
                              jax.lax.broadcast(y, [N]))
            mask = jax.lax.convert_element_type(mask, np.float32)
            return mask

        def make_self_mask(N, M, k):  # pylint: disable=invalid-name
            """Masks out elements attending to self.

      Args:
        N: number of query positions
        M: number of key positions
        k: position of the initial query element

      Returns:
        N x M mask, where 1.0 indicates that attention is not allowed.
      """
            x = jax.lax.tie_in(k, np.arange(N, dtype=np.int32))
            y = jax.lax.tie_in(k, np.arange(M, dtype=np.int32))
            mask = jax.lax.eq((jax.lax.broadcast_in_dim(
                x, shape=(N, M), broadcast_dimensions=(0, )) + k),
                              jax.lax.broadcast(y, [N]))
            mask = jax.lax.convert_element_type(mask, np.float32)
            return mask

        def forward_slice(query_slice, q_loop_idx, key, value):  # pylint: disable=invalid-name
            """Forward pass for a subset of the query vectors."""
            if self._share_qk:
                key = self.make_unit_length(key)

            dots = np.matmul(
                query_slice, np.swapaxes(key, -1, -2)) / np.sqrt(depth)

            # Causal masking
            mask = make_mask(dots.shape[-2], dots.shape[-1], q_loop_idx)
            dots = dots - 1e9 * mask

            # Mask out attention to self except when no other targets are available.
            if self._share_qk:
                self_mask = make_self_mask(dots.shape[-2], dots.shape[-1],
                                           q_loop_idx)
                dots = dots - 1e5 * self_mask

            # Softmax.
            dots = np.exp(dots -
                          backend.logsumexp(dots, axis=-1, keepdims=True))

            if self.dropout is not None and self.dropout > 0.0:
                # Dropout is broadcast across the batch+head dimension
                dropout_shape = (1, dots.shape[-2], dots.shape[-1])
                slice_rng = jax.random.fold_in(rng, q_loop_idx)
                keep_prob = jax.lax.tie_in(dots, 1.0 - self.dropout)
                keep = backend.random.bernoulli(slice_rng, keep_prob,
                                                dropout_shape)
                multiplier = keep.astype(dots.dtype) / jax.lax.tie_in(
                    keep, keep_prob)
                dots = dots * multiplier

            if self._hard_k > 0:
                # Keep only the top-k weights: get the top-kth weight, subtract it
                # (entries below it drop to <= 0), clip at zero and re-normalize.
                top_k = np.sort(dots)[..., -self._hard_k]
                top_k = jax.lax.stop_gradient(top_k)
                dots -= top_k[..., np.newaxis]
                dots = np.maximum(dots, 0)
                dots_sum = np.sum(dots, axis=-1, keepdims=True)
                dots /= dots_sum  # Re-normalize.

            out_slice = np.matmul(dots, value)
            return out_slice

        def forward_and_vjp_slice(query_slice, q_loop_idx, key, value,
                                  ct_slice):  # pylint: disable=invalid-name
            # Capture q_loop_idx to avoid calculating gradients with respect to it.
            def forward_slice_with_q_loop_idx(query_slice, key, value):  # pylint: disable=invalid-name
                return forward_slice(query_slice, q_loop_idx, key, value)

            output_slice, vjpfun = jax.vjp(forward_slice_with_q_loop_idx,
                                           query_slice, key, value)
            return output_slice, vjpfun(ct_slice)

        q_loop_idx = np.zeros((), dtype=np.int32)
        q_loop_max = query.shape[-2]
        q_loop_stride = self._loop_stride
        if q_loop_max == 1:  # For abstract runs with unknown shapes.
            q_loop_stride = 1
        assert q_loop_max % q_loop_stride == 0, (
            'Stride must evenly divide the number of query elements.')

        out_accum = np.zeros_like(query)
        if do_backprop:
            query_ct_accum = np.zeros_like(query)
            key_ct_accum = np.zeros_like(key)
            value_ct_accum = np.zeros_like(value)
            init_vals = (q_loop_idx, out_accum, query_ct_accum, key_ct_accum,
                         value_ct_accum)
        else:
            init_vals = (q_loop_idx, out_accum)

        def cond_fun(vals):  # pylint: disable=invalid-name
            q_loop_idx = vals[0]
            return jax.lax.lt(q_loop_idx, q_loop_max)

        def body_fun(vals):  # pylint: disable=invalid-name
            """Compute a slice of the attention mechanism."""
            if do_backprop:
                (q_loop_idx, out_accum, query_ct_accum, key_ct_accum,
                 value_ct_accum) = vals
            else:
                q_loop_idx, out_accum = vals

            query_slice = jax.lax.dynamic_slice_in_dim(query,
                                                       q_loop_idx,
                                                       q_loop_stride,
                                                       axis=-2)

            if do_backprop:
                ct_slice = jax.lax.dynamic_slice_in_dim(ct,
                                                        q_loop_idx,
                                                        q_loop_stride,
                                                        axis=-2)
                out_slice, partial_ct = forward_and_vjp_slice(
                    query_slice, q_loop_idx, key, value, ct_slice)
                query_ct_accum = jax.lax.dynamic_update_slice_in_dim(
                    query_ct_accum, partial_ct[0], q_loop_idx, axis=-2)
                key_ct_accum = key_ct_accum + partial_ct[1]
                value_ct_accum = value_ct_accum + partial_ct[2]
            else:
                out_slice = forward_slice(query_slice, q_loop_idx, key, value)

            out_accum = jax.lax.dynamic_update_slice_in_dim(out_accum,
                                                            out_slice,
                                                            q_loop_idx,
                                                            axis=-2)
            q_loop_idx = q_loop_idx + q_loop_stride

            if do_backprop:
                return (q_loop_idx, out_accum, query_ct_accum, key_ct_accum,
                        value_ct_accum)
            else:
                return (q_loop_idx, out_accum)

        final_vals = jax.lax.while_loop(cond_fun, body_fun, init_vals)

        if not do_backprop:
            return final_vals[1], None
        else:
            return final_vals[1], final_vals[2:]
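make_mask above builds, for a chunk of queries starting at offset k, an N x M matrix with 1.0 exactly where the key position lies in the query's future. A NumPy restatement with a small worked case (causal_mask_slice is a hypothetical name):

import numpy as np

def causal_mask_slice(n_queries, n_keys, start):
    # Entry (i, j) is 1.0 when key j is NOT visible to query (start + i),
    # i.e. when start + i < j -- matching the jax.lax.lt comparison in make_mask.
    q_pos = np.arange(n_queries)[:, None] + start
    k_pos = np.arange(n_keys)[None, :]
    return (q_pos < k_pos).astype(np.float32)

m = causal_mask_slice(2, 5, start=1)
# The query at absolute position 1 may attend to keys 0..1, position 2 to keys 0..2.
assert m.tolist() == [[0., 0., 1., 1., 1.],
                      [0., 0., 0., 1., 1.]]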