Code example #1
File: attention.py Project: galloperx/trax
def DotProductAttention(query, key, value, mask, dropout, mode, rng):
    """Core dot product self-attention.

  Args:
    query: array of representations
    key: array of representations
    value: array of representations
    mask: attention-mask, gates attention
    dropout: float: dropout rate
    mode: 'eval' or 'train': whether to use dropout
    rng: JAX PRNGKey: subkey for disposable use

  Returns:
    Self attention for q, k, v arrays.
  """
    depth = np.shape(query)[-1]
    dots = np.matmul(query, np.swapaxes(key, -1, -2)) / np.sqrt(depth)
    if mask is not None:
        # TODO(kitaev): workaround for https://github.com/google/jax/issues/850
        # We must ensure that both mask and the -1e9 constant have a data dependency
        # on the input. Broadcasted copies of these use a lot of memory, so they
        # should be computed at runtime (rather than being global constants).
        if math.backend_name() == 'jax':
            mask = jax.lax.tie_in(dots, mask)
        # JAX's `full_like` already ties in -1e9 to dots.
        dots = np.where(mask, dots, np.full_like(dots, -1e9))
    # Softmax.
    dots = np.exp(dots - math.logsumexp(dots, axis=-1, keepdims=True))
    if dropout is not None and dropout >= 1.0:
        raise ValueError('Dropout rates must be lower than 1.')
    if dropout is not None and dropout > 0.0 and mode == 'train':
        keep = math.random.bernoulli(rng, 1.0 - dropout, dots.shape)
        dots = np.where(keep, dots / (1.0 - dropout), np.zeros_like(dots))
    out = np.matmul(dots, value)
    return out
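
A minimal usage sketch (not from the source) showing the call convention of DotProductAttention above: per-head query/key/value arrays plus a boolean mask that broadcasts against the (..., seqlen, seqlen) attention logits. The jax.numpy imports, shapes, and causal mask are illustrative assumptions.

import jax
import jax.numpy as jnp

rng = jax.random.PRNGKey(0)
batch, n_heads, seqlen, d_head = 2, 4, 8, 16
shape = (batch, n_heads, seqlen, d_head)
q = jnp.ones(shape)
k = jnp.ones(shape)
v = jnp.ones(shape)
# Lower-triangular (causal) mask, broadcastable to (batch, n_heads, seqlen, seqlen).
mask = jnp.tril(jnp.ones((1, 1, seqlen, seqlen), dtype=bool))
out = DotProductAttention(q, k, v, mask, dropout=0.0, mode='eval', rng=rng)
assert out.shape == shape  # one output vector per head and query position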
Code example #2
 def forward_with_state(self, inputs, weights=base.EMPTY_WEIGHTS,
                        state=base.EMPTY_STATE, rng=None):
   if self._mode in ('train', 'eval'):
     x = inputs
     symbol_size = jnp.shape(x)[1]
     px = weights[:, :symbol_size, :]
     if self._dropout == 0:
       return (x + px, state)
     else:
       noise_shape = list(px.shape)
       for dim in self._dropout_broadcast_dims:
         noise_shape[dim] = 1
       keep_prob = 1.0 - self._dropout
       if math.backend_name() == 'jax':
         keep_prob = jax.lax.tie_in(x, jnp.full((), keep_prob, dtype=x.dtype))
       keep = math.random.bernoulli(rng, keep_prob, tuple(noise_shape))
       multiplier = keep.astype(x.dtype) / keep_prob
       return (x + px * multiplier, state)
   else:
     assert self._mode == 'predict'
     assert self._dropout == 0
     # State in this class is only used for fast inference. In that case,
     # the model is called with consecutive elements position-by-position.
     # This positional encoding layer needs to store the index of the current
     # position then and increment it on each call -- that's how state is used
     # and updated below.
     if inputs.shape[1] == 1:
       return (inputs + jnp.expand_dims(weights[0, state, :], 1), state + 1)
     else:
       emb = []
       for i in range(inputs.shape[0]):
         emb.append(jax.lax.dynamic_slice_in_dim(
             weights[0], state[i], inputs.shape[1], axis=0))
       return inputs + jnp.stack(emb, 0), state + inputs.shape[1]
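
The noise_shape loop above implements broadcasted dropout: every dimension listed in self._dropout_broadcast_dims is collapsed to 1, so one keep/drop decision is shared along it. A standalone sketch of that idea (the function name, shapes, and broadcast dims are illustrative, not the layer's API):

import jax
import jax.numpy as jnp

def broadcast_dropout(x, rate, broadcast_dims, rng):
  # Sample the keep mask with size 1 along the broadcast dims; broadcasting
  # then reuses the same decision across those dims.
  noise_shape = list(x.shape)
  for dim in broadcast_dims:
    noise_shape[dim] = 1
  keep_prob = 1.0 - rate
  keep = jax.random.bernoulli(rng, keep_prob, tuple(noise_shape))
  return x * keep.astype(x.dtype) / keep_prob

px = jnp.ones((1, 10, 8))   # positional embeddings: (1, length, d_feature)
rng = jax.random.PRNGKey(0)
dropped = broadcast_dropout(px, rate=0.1, broadcast_dims=(-2,), rng=rng)
# With dim -2 broadcast, each feature channel is kept or dropped jointly
# across all positions.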
Code example #3
  def forward_with_state(self, x, weights, state, rng):
    """Pure transformer-style multi-headed attention.

    Args:
      x: inputs (q, k, v, mask)
      weights: parameters (none)
      state: layer state (unused by this layer)
      rng: Single-use random number generator (JAX PRNG key).

    Returns:
      Pure Multi-headed attention result, and the mask.
    """
    del weights
    n_heads, dropout, mode = self._n_heads, self._dropout, self._mode
    q, k, v, mask = x
    d_feature = q.shape[-1]
    assert d_feature % n_heads == 0
    d_head = d_feature // n_heads
    nbatch = jnp.shape(q)[0]
    # nbatch, seqlen, d_feature --> nbatch, n_heads, seqlen, d_head
    def SplitHeads(x):
      return jnp.transpose(
          jnp.reshape(x, (nbatch, -1, n_heads, d_head)), (0, 2, 1, 3))
    # nbatch, n_heads, seqlen, d_head --> nbatch, seqlen, d_feature
    def JoinHeads(x):  # pylint: disable=invalid-name
      return jnp.reshape(
          jnp.transpose(x, (0, 2, 1, 3)), (nbatch, -1, n_heads * d_head))
    # Split heads, dot-product attention, rejoin heads.
    res = JoinHeads(
        DotProductAttention(
            SplitHeads(q), SplitHeads(k), SplitHeads(v), mask,
            dropout=dropout, mode=mode, rng=rng))
    return (res, mask), state  # Keep the mask.
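
The SplitHeads/JoinHeads pair above is a reshape/transpose round trip between (nbatch, seqlen, d_feature) and (nbatch, n_heads, seqlen, d_head). A standalone sketch of that shape bookkeeping (sizes are illustrative):

import jax.numpy as jnp

nbatch, seqlen, n_heads, d_head = 2, 5, 4, 16
d_feature = n_heads * d_head

x = jnp.arange(nbatch * seqlen * d_feature, dtype=jnp.float32).reshape(
    nbatch, seqlen, d_feature)

# nbatch, seqlen, d_feature --> nbatch, n_heads, seqlen, d_head
split = jnp.transpose(jnp.reshape(x, (nbatch, -1, n_heads, d_head)), (0, 2, 1, 3))
assert split.shape == (nbatch, n_heads, seqlen, d_head)

# ... and back again; the round trip is the identity.
joined = jnp.reshape(jnp.transpose(split, (0, 2, 1, 3)), (nbatch, -1, n_heads * d_head))
assert joined.shape == x.shape
assert bool(jnp.all(joined == x))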
Code example #4
  def forward(self, inputs, weights):
    gamma, beta, epsilon_l = weights

    epsilon = self._init_epsilon
    if epsilon_l is not base.EMPTY_WEIGHTS:
      epsilon += np.abs(epsilon_l[0])

    # Reduce over the spatial axes only; omit the batch (B) and channel (C) dims.
    axis = tuple(range(1, len(np.shape(inputs)) - 1))
    # nu2 has shape (B, 1, 1, C) for NHWC inputs.
    nu2 = np.mean(inputs**2, axis=axis, keepdims=True)
    # (B, W, H, C)
    xhat = inputs / np.sqrt(nu2 + epsilon)

    return gamma * xhat + beta
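
The body above is an RMS-style normalization over the spatial axes (with a learned offset added to epsilon): nu2 is the per-batch, per-channel mean of squares over W and H, the input is divided by sqrt(nu2 + epsilon), and gamma/beta rescale and shift the result. A standalone sketch with plain NumPy and illustrative shapes:

import numpy as np

B, W, H, C = 2, 4, 4, 3
inputs = np.random.rand(B, W, H, C).astype(np.float32)
gamma = np.ones((1, 1, 1, C), dtype=np.float32)
beta = np.zeros((1, 1, 1, C), dtype=np.float32)
epsilon = 1e-6

# Mean of squares over the spatial axes only, keeping batch and channel dims.
nu2 = np.mean(inputs ** 2, axis=(1, 2), keepdims=True)  # (B, 1, 1, C)
xhat = inputs / np.sqrt(nu2 + epsilon)                   # (B, W, H, C)
out = gamma * xhat + beta
assert out.shape == inputs.shape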
Code example #5
File: attention.py Project: yangcaot/trax
    def forward(self, inputs):
        if self._mode != 'predict':
            x = inputs
            symbol_size = jnp.shape(x)[1]
            px = self.weights[:, :symbol_size, :]
            if self._dropout == 0:
                return x + px
            else:
                noise_shape = list(px.shape)
                for dim in self._dropout_broadcast_dims:
                    noise_shape[dim] = 1
                keep_prob = 1.0 - self._dropout
                if math.backend_name() == 'jax':
                    keep_prob = jax.lax.tie_in(
                        x, jnp.full((), keep_prob, dtype=x.dtype))
                keep = math.random.bernoulli(self.rng, keep_prob,
                                             tuple(noise_shape))
                multiplier = keep.astype(x.dtype) / keep_prob
                return x + px * multiplier
        else:
            if self._dropout != 0:
                raise ValueError(f'In predict mode, but dropout rate '
                                 f'({self._dropout}) is not zero.')

            # State in this class is only used for fast inference. In that case,
            # the model is called with consecutive elements position-by-position.
            # This positional encoding layer needs to store the index of the current
            # position then and increment it on each call -- that's how state is used
            # and updated below.
            state = self.state
            if inputs.shape[1] == 1:
                self.state = state + 1
                return inputs + jnp.expand_dims(self.weights[0, state, :], 1)
            else:
                emb = []
                for i in range(inputs.shape[0]):
                    emb.append(
                        jax.lax.dynamic_slice_in_dim(self.weights[0],
                                                     state[i],
                                                     inputs.shape[1],
                                                     axis=0))
                self.state = state + inputs.shape[1]
                return inputs + jnp.stack(emb, 0)
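
In the predict branch above, state holds the current decoding position for each batch element and jax.lax.dynamic_slice_in_dim picks the matching rows of the positional-embedding table. A standalone sketch of just that slicing (the table, positions, and chunk length are illustrative):

import jax
import jax.numpy as jnp

table = jnp.arange(20.0).reshape(10, 2)  # (max_len, d_feature) embedding table
state = jnp.array([3, 7])                # current position for each of 2 batch rows
chunk_len = 2                            # number of new tokens fed in this call

emb = []
for i in range(state.shape[0]):
  # Take chunk_len consecutive rows starting at this example's position.
  emb.append(jax.lax.dynamic_slice_in_dim(table, state[i], chunk_len, axis=0))
emb = jnp.stack(emb, 0)                  # (batch, chunk_len, d_feature)
new_state = state + chunk_len            # advance the positions for the next call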
Code example #6
File: position_encodings.py Project: zongdaofu/trax
 def forward_with_state(self,
                        x,
                        weights=layer_base.EMPTY_WEIGHTS,
                        state=layer_base.EMPTY_STATE,
                        rng=None,
                        **kwargs):
     length = np.shape(x)[1]
     max_pos = self._base**self._n_digits
     assert length < max_pos, 'length (%d) >= max_pos (%d)' % (length,
                                                               max_pos)
     positions = np.arange(0, length)
     if self._mode == 'train':
         positions += jax.random.randint(rng, (), 0, max_pos - length)
     pos_embeddings = []
     cur_positions = positions
     for i in range(self._n_digits):
         cur_indices = np.mod(cur_positions, self._base)
         cur_positions //= self._base
         pos_embeddings.append(np.take(weights[i], cur_indices, axis=0))
     embeddings = np.concatenate(pos_embeddings, axis=-1)
     return (x + embeddings[None, :, :], state)
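
The loop above encodes each position by its digits in base self._base and concatenates one embedding per digit. A small sketch of just the digit extraction (base, digit count, and positions are illustrative):

import numpy as np

base, n_digits = 4, 3        # positions 0 .. base**n_digits - 1 are representable
positions = np.arange(0, 6)

digits = []
cur_positions = positions
for _ in range(n_digits):
  digits.append(np.mod(cur_positions, base))  # least-significant digit first
  cur_positions = cur_positions // base
# digits[0] == [0 1 2 3 0 1], digits[1] == [0 0 0 0 1 1], digits[2] == [0 0 0 0 0 0]
# Each digit array indexes its own embedding table (weights[i] above), and the
# n_digits embeddings are concatenated along the feature axis.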
Code example #7
def NewPositionalEncoding(x, positions=None):
  """Implements new positional encoding."""
  x_length = np.shape(x)[1]
  pos = np.array(positions)[np.newaxis, :x_length, :]
  pos += np.zeros((np.shape(x)[0], 1, 1))  # Broadcast on batch.
  return pos
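
A minimal usage sketch (assumed, not from the source): NewPositionalEncoding returns the position table truncated to the input length and broadcast over the batch; it does not add it to x. The array shapes below are illustrative.

import numpy as np

x = np.zeros((2, 5, 8))                     # (batch, length, d_feature)
positions = np.arange(24.0).reshape(12, 2)  # precomputed (max_len, d_pos) table
pos = NewPositionalEncoding(x, positions=positions)
# pos.shape == (2, 5, 2): the first 5 rows of `positions`, repeated per batch entry.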