Exemplo n.º 1
0
 def fn(dist_inputs, actions, q_values, act_log_probs, mask):
     """Returns advantages: each sampled Q-value minus a per-state baseline.

     Only ``q_values`` and ``act_log_probs`` are consumed; the other arguments
     are accepted to match the expected signature and then discarded.
     """
     del dist_inputs, actions, mask
     # Bring the sample axis to the front. The layout is presumably
     # (batch, n_samples, ...) -> (n_samples, batch, ...); confirm with caller.
     q_per_sample = jnp.swapaxes(q_values, 0, 1)
     log_p_per_sample = jnp.swapaxes(act_log_probs, 0, 1)
     if not self._sample_all_discrete_actions:
         # Plain average over the sampled actions.
         baseline = jnp.mean(q_per_sample, axis=0)
     else:
         # Probability-weighted sum of Q over the enumerated actions.
         baseline = jnp.sum(q_per_sample * jnp.exp(log_p_per_sample), axis=0)
     # `baseline` has no sample axis, so it broadcasts over n_samples.
     result = q_per_sample - baseline
     return self._preprocess_advantages(result) if preprocess else result
Exemplo n.º 2
0
 def gather_fn(x):
   """Gather slices for a single tensor.

   Uses ``batch_size``, ``batch_indices`` and ``beam_indices`` from the
   enclosing scope. Scalars (e.g. a cache index) are returned untouched. A
   tensor whose leading dim equals ``batch_size`` is indexed directly.
   Otherwise the leading dim must be a multiple of ``batch_size``. It is split
   into ``(batch_size, k)``, axis 1 of ``x`` is moved next to the batch axis
   for the gather, and the original packing is then restored.
   """
   if x.ndim == 0:  # ignore scalars (e.g. cache index)
     return x
   if x.shape[0] == batch_size:
     return x[batch_indices, beam_indices]
   assert x.shape[0] % batch_size == 0
   # (batch * k, d1, ...) -> (batch, k, d1, ...) -> (batch, d1, k, ...)
   grouped = np.swapaxes(x.reshape((batch_size, -1) + x.shape[1:]), 1, 2)
   # Gather on the (batch, d1) axes, then move k back next to batch.
   picked = np.swapaxes(grouped[batch_indices, beam_indices], 1, 2)
   return picked.reshape((-1,) + picked.shape[2:])
Exemplo n.º 3
0
 def significance_weights(mask):
   """Weights ``mask`` by ``decay ** significance`` and flattens each batch row.

   ``serializer.significance_map`` (length ``repr``) is repeated twice so it
   spans ``mask.shape[2] == 2 * repr`` (e.g. [0, 1, 2] -> [0, 1, 2, 0, 1, 2]).
   It is then broadcast over the batch and length axes of ``mask``. The weighted
   mask is reshaped to ``(batch, length * 2 * repr)``.
   """
   base = serializer.significance_map
   assert base.shape[0] * 2 == mask.shape[2]
   # [0, 1, 2] -> [0, 1, 2, 0, 1, 2]
   doubled = jnp.tile(base, 2)
   assert doubled.shape[0] == mask.shape[2]
   # (2 * repr,) -> (batch, length, 2 * repr): every position shares the values.
   significance = jnp.broadcast_to(
       doubled, (mask.shape[0], mask.shape[1], doubled.shape[0]))
   assert significance.shape == mask.shape
   weighted = mask * decay ** significance
   n_rows = weighted.shape[0]
   n_cols = weighted.shape[1] * weighted.shape[2]
   # TODO(henrykm): Make sure that the reshape works in the desired way
   return np.reshape(weighted, (n_rows, n_cols))
Exemplo n.º 4
0
def DotProductAttention(query, key, value, mask):
    """Dot product self-attention.

    Computes ``softmax(query @ key^T / sqrt(d)) @ value``. Any position where
    ``mask`` is falsy gets a logit of -1e9 before the softmax.

    Args:
        query (jax.interpreters.xla.DeviceArray): array of query representations with shape (L_q by d)
        key (jax.interpreters.xla.DeviceArray): array of key representations with shape (L_k by d)
        value (jax.interpreters.xla.DeviceArray): array of value representations with shape (L_k by d) where L_v = L_k
        mask (jax.interpreters.xla.DeviceArray): attention-mask, gates attention with shape (L_q by L_k),
            or None to attend to every key.

    Returns:
        jax.interpreters.xla.DeviceArray: Self-attention array for q, k, v arrays. (L_q by d)

    Raises:
        AssertionError: if the last dimensions of query, key and value differ.
    """

    assert query.shape[-1] == key.shape[-1] == value.shape[
        -1], "Embedding dimensions of q, k, v aren't all the same"

    depth = query.shape[-1]

    # Scaled dot product: (L_q, d) @ (d, L_k) / sqrt(d) -> (L_q, L_k).
    dots = jnp.matmul(query, jnp.swapaxes(key, -1, -2)) / jnp.sqrt(depth)

    if mask is not None:
        # Masked-out positions get a huge negative logit, i.e. ~zero weight.
        dots = jnp.where(mask, dots, jnp.full_like(dots, -1e9))

    # Numerically stable softmax: exp(dots - logsumexp(dots)).
    logsumexp = trax.fastmath.logsumexp(dots, axis=-1, keepdims=True)

    dots = jnp.exp(dots - logsumexp)
    # (L_q, L_k) @ (L_k, d) -> (L_q, d).
    attention = jnp.matmul(dots, value)

    return attention
Exemplo n.º 5
0
    def _run_value_model(self, observations, dist_inputs):
        """Runs the value network, optionally on sampled or enumerated actions.

        Args:
          observations: Observation array. Its first two dims become the
            leading shape of the default ``dist_inputs`` (presumably
            (batch, length); confirm with caller).
          dist_inputs: Policy-distribution inputs, or None to use zeros of
            shape ``observations.shape[:2] + (self._policy_dist.n_inputs,)``.

        Returns:
          Tuple ``(values, actions, log_probs)``. ``values`` is the value-model
          output times ``self._value_network_scale``, with its last (singleton)
          axis squeezed. ``actions`` and ``log_probs`` are None unless
          ``self._q_value`` is set. In that case they hold the actions fed to
          the Q-value model and their log-probs under ``self._policy_dist``.
        """
        if dist_inputs is None:
            dist_inputs = jnp.zeros(observations.shape[:2] +
                                    (self._policy_dist.n_inputs, ))

        actions = None
        if self._q_value:
            if self._sample_all_discrete_actions:
                # Since we want to sample all actions, start by creating their list.
                act = np.arange(self._vocab_size)
                # Now act is a vector [0, ..., vocab_size-1], but we'll need to tile it.
                # Add extra dimensions so it's the same dimensionality as dist_inputs.
                act = jnp.reshape(act,
                                  [-1] + [1] * (len(dist_inputs.shape) - 1))
                # Now act is [vocab_size, 1, ..., 1], dimensionality of dist_inputs.
            # Prepend a sample axis of size self._q_value_n_samples.
            dist_inputs = jnp.broadcast_to(
                dist_inputs, (self._q_value_n_samples, ) + dist_inputs.shape)
            if self._sample_all_discrete_actions:
                # Broadcast `act` (leading dim vocab_size) against the
                # sample-prefixed shape minus its last axis (leading dim
                # n_samples). NOTE(review): this only broadcasts when
                # self._q_value_n_samples == self._vocab_size (or one is 1).
                actions = act + jnp.zeros(dist_inputs.shape[:-1],
                                          dtype=jnp.int32)
                actions = jnp.swapaxes(actions, 0, 1)
            # Swapping the n_samples and batch_size axes, so the input is split
            # between accelerators along the batch_size axis.
            dist_inputs = jnp.swapaxes(dist_inputs, 0, 1)
            if not self._sample_all_discrete_actions:
                actions = self._policy_dist.sample(dist_inputs)
            log_probs = self._policy_dist.log_prob(dist_inputs, actions)
            obs = observations
            # Insert a singleton axis at position 1, presumably so the
            # observations broadcast against the sample axis of `actions`.
            obs = jnp.reshape(obs, [obs.shape[0], 1] + list(obs.shape[1:]))
            inputs = (obs, actions)
        else:
            log_probs = None
            inputs = (observations, )

        # Prepare weights/state for every device before the jitted call
        # (tl.for_n_devices; presumably replicates them per device).
        n_devices = fastmath.device_count()
        weights = tl.for_n_devices(self._value_eval_model.weights, n_devices)
        state = tl.for_n_devices(self._value_eval_model.state, n_devices)
        rng = self._value_eval_model.rng
        values, _ = self._value_eval_jit(inputs, weights, state, rng)
        values *= self._value_network_scale
        values = jnp.squeeze(values,
                             axis=-1)  # Remove the singleton depth dim.
        return (values, actions, log_probs)
Exemplo n.º 6
0
def _per_head_attention(queries, keys, values, mask, dropout, mode, rng):
  """Computes new per-head activations via scaled dot-product attention.

  This function is the core of the attention mechanism. Given per-head
  ``queries`` (Q), ``keys`` (K), ``values`` (V), and ``mask``, it:

    - computes the scaled dot product of each Q-K pair;
    - applies ``mask`` to screen out positions that come from padding tokens
      (indicated by 0 value);
    - [in ``'train'`` mode] applies dropout to Q-K dot products;
    - computes Q-K attention strengths using a per-query softmax of the Q-K dot
      products; and
    - for each query position, combines V vectors according to the Q-K
      attention strengths.

  Args:
    queries: Per-head activations representing attention queries.
    keys: Per-head activations representing attention keys.
    values: Per-head activations to be combined by computed attention strengths.
    mask: Mask that distinguishes positions with real content vs. padding.
    dropout: Probababilistic rate for attention dropout, which overrides
        (sets to zero) some attention strengths derived from query-key
        matching. As a result, on a given forward pass, some value vectors
        don't contribute to the output, analogous to how regular dropout can
        cause some node activations to be ignored. Applies only in ``'train'``
        mode.
    mode: One of ``'train'``, ``'eval'``, or ``'predict'``.
    rng: Single-use random number generator (JAX PRNG key).

  Returns:
    Tuple of (activations, attn_strengths), where activations are new per-head
    activation vectors and attn_strengths is a matrix of per-head attention
    strengths.
  """
  if dropout >= 1.0:
    raise ValueError(f'Dropout rate ({dropout}) must be lower than 1.')

  d_feature = queries.shape[-1]

  dots = jnp.matmul(queries, jnp.swapaxes(keys, -1, -2)) / jnp.sqrt(d_feature)
  if mask is not None:
    dots = jnp.where(mask,
                     dots,
                     jnp.full_like(dots, -1e9))
  attn_strengths = (
      jnp.exp(dots - fastmath.logsumexp(dots, axis=-1, keepdims=True)))
  if dropout is not None and dropout > 0.0 and mode == 'train':
    keep = fastmath.random.bernoulli(rng, 1.0 - dropout, attn_strengths.shape)
    attn_strengths = jnp.where(keep,
                               attn_strengths / (1.0 - dropout),
                               jnp.zeros_like(attn_strengths))
  activations = jnp.matmul(attn_strengths, values).astype(jnp.float32)
  attn_strengths = attn_strengths.astype(jnp.float32)
  return activations, attn_strengths
Exemplo n.º 7
0
        def LossInput(dist_inputs, actions, q_values, act_log_probs, mask):  # pylint: disable=invalid-name
            """Calculates action log probabilities and normalizes advantages.

            Returns:
              Tuple ``(log_probs, advantages, act_log_probs, mask)``. All
              entries have their first two axes swapped relative to the inputs,
              so the sample axis leads.
            """
            # (batch_size, n_samples, ...) -> (n_samples, batch_size, ...)
            q_values, mask, actions, act_log_probs = (
                jnp.swapaxes(arr, 0, 1)
                for arr in (q_values, mask, actions, act_log_probs))

            # TODO(pkozakowski,lukaszkaiser): Try max here, or reweighting?
            if self._sample_all_discrete_actions:
                # Probability-weighted sum of Q over the enumerated actions.
                baseline = jnp.sum(q_values * jnp.exp(act_log_probs), axis=0)
            else:
                baseline = jnp.mean(q_values, axis=0)
            # `baseline` has no sample axis, so it broadcasts over n_samples.
            advantages = self._preprocess_advantages(q_values - baseline)

            # Give dist_inputs a leading n_samples axis to match `actions`,
            # then score the actions under that distribution.
            tiled_inputs = jnp.broadcast_to(
                dist_inputs, (self._q_value_n_samples, ) + dist_inputs.shape)
            log_probs = self._policy_dist.log_prob(tiled_inputs, actions)
            return (log_probs, advantages, act_log_probs, mask)
Exemplo n.º 8
0
def unflatten_beam_dim(x, batch_size, beam_size):
  """Splits the leading, flattened batch*beam dimension of a non-scalar array.

  Scalars (e.g. a cache index) pass through unchanged. If ``x.shape[0]``
  equals ``batch_size * beam_size``, the result has shape
  ``(batch_size, beam_size) + x.shape[1:]``. Otherwise ``x.shape[0]`` must be
  ``batch_size * beam_size * k``. The leading axis is then read as
  ``(batch, beam, k)``, and the result has shape
  ``(batch_size * k, beam_size) + x.shape[1:]``.
  """
  if x.ndim == 0:  # ignore scalars (e.g. cache index)
    return x
  flat = batch_size * beam_size
  if x.shape[0] == flat:
    return x.reshape((batch_size, beam_size) + x.shape[1:])
  assert x.shape[0] % flat == 0
  split = x.reshape((batch_size, beam_size, -1) + x.shape[1:])
  # (batch, beam, k, ...) -> (batch, k, beam, ...), then merge batch with k.
  regrouped = np.swapaxes(split, 1, 2)
  return regrouped.reshape((-1, beam_size) + regrouped.shape[3:])
Exemplo n.º 9
0
def flatten_beam_dim(x, batch_size=None):
  """Merges the first two dimensions of a non-scalar array into one.

  Scalars (e.g. a cache index) are returned unchanged. When ``batch_size`` is
  None or equals ``x.shape[0]``, the result has shape
  ``(x.shape[0] * x.shape[1],) + x.shape[2:]``. Otherwise ``x.shape[0]`` must
  be ``batch_size * k``. The leading axis is read as ``(batch, k)``,
  ``x.shape[1]`` is moved next to the batch axis, and the three axes are
  flattened together.
  """
  if x.ndim == 0:  # ignore scalars (e.g. cache index)
    return x
  if batch_size is None or x.shape[0] == batch_size:
    return x.reshape((x.shape[0] * x.shape[1],) + x.shape[2:])
  assert x.shape[0] % batch_size == 0
  # (batch * k, d1, ...) -> (batch, k, d1, ...) -> (batch, d1, k, ...)
  regrouped = np.swapaxes(
      x.reshape((batch_size, -1, x.shape[1]) + x.shape[2:]), 1, 2)
  merged = regrouped.shape[0] * regrouped.shape[1] * regrouped.shape[2]
  return regrouped.reshape((merged,) + regrouped.shape[3:])
Exemplo n.º 10
0
 def significance_weights(mask):
   """Scales ``mask`` by ``decay ** significance`` along its last axis.

   ``serializer.significance_map`` (length ``repr``, e.g. [0, 1, 2]) must match
   ``mask.shape[2]``. The same map is used at every (batch, length) position.
   """
   per_repr = serializer.significance_map
   assert per_repr.shape[0] == mask.shape[2]
   batch, length = mask.shape[0], mask.shape[1]
   # (repr,) -> (batch, length, repr)
   significance = jnp.broadcast_to(per_repr, (batch, length, per_repr.shape[0]))
   assert significance.shape == mask.shape
   return mask * decay ** significance
Exemplo n.º 11
0
def DotProductAttention(queries, keys, values, mask, dropout, mode, rng):
  """Computes new activations via masked attention-weighted sum of values.

  This function is the core of the attention mechanism. It:
    - computes per-head attention weights from per-head `queries` and `keys`,
    - applies `mask` to screen out positions that come from padding tokens,
    - optionally applies dropout to attention weights, and
    - uses attention weights to combine per-head `values` vectors.

  Args:
    queries: Per-head activations representing attention queries.
    keys: Per-head activations representing attention keys.
    values: Per-head activations to be combined by computed attention weights.
    mask: Mask that distinguishes positions with real content vs. padding, or
        None for no masking.
    dropout: Probabilistic rate for dropout applied to attention strengths
        (based on query-key pairs) before applying them to values. May be
        None, meaning no dropout. Only applied when `mode == 'train'`.
    mode: One of `'train'`, `'eval'`, or `'predict'`.
    rng: Single-use random number generator (JAX PRNG key).

  Returns:
    Tuple `(out, dots)`. `out` holds the per-head activations resulting from
    the masked, attention-weighted sum of per-head values. `dots` holds the
    attention weights actually applied (after any dropout). Both are cast to
    float32.

  Raises:
    ValueError: If `dropout` is not None and is >= 1.0.
  """
  # Validate before doing any work. `dropout` may be None (the dropout branch
  # below explicitly allows it), and `None >= 1.0` raises TypeError.
  if dropout is not None and dropout >= 1.0:
    raise ValueError('Dropout rates must be lower than 1.')
  d_feature = queries.shape[-1]
  dots = jnp.matmul(queries, jnp.swapaxes(keys, -1, -2)) / jnp.sqrt(d_feature)
  if mask is not None:
    # TODO(kitaev): workaround for https://github.com/google/jax/issues/850
    # We must ensure that both mask and the -1e9 constant have a data dependency
    # on the input. Broadcasted copies of these use a lot of memory, so they
    # should be computed at runtime (rather than being global constants).
    if fastmath.is_backend(fastmath.Backend.JAX):
      mask = jax.lax.tie_in(dots, mask)
    # JAX's `full_like` already ties in -1e9 to dots.
    dots = jnp.where(mask, dots, jnp.full_like(dots, -1e9))
  # Softmax.
  dots = jnp.exp(dots - fastmath.logsumexp(dots, axis=-1, keepdims=True))
  if dropout is not None and dropout > 0.0 and mode == 'train':
    keep = fastmath.random.bernoulli(rng, 1.0 - dropout, dots.shape)
    # Inverted dropout: rescale survivors so the expected value is unchanged.
    dots = jnp.where(keep, dots / (1.0 - dropout), jnp.zeros_like(dots))
  out = jnp.matmul(dots, values)
  out = out.astype(jnp.float32)
  dots = dots.astype(jnp.float32)
  return out, dots
Exemplo n.º 12
0
def DotProductAttention(queries, keys, values, mask, dropout, mode, rng):
    """Computes new activations via masked attention-weighted sum of values.

    This function is the core of the attention mechanism. It:
      - computes per-head attention weights from per-head ``queries`` and
        ``keys``,
      - applies ``mask`` to screen out positions that come from padding tokens,
      - optionally applies dropout to attention weights, and
      - uses attention weights to combine per-head ``values`` vectors.

    Args:
      queries: Per-head activations representing attention queries.
      keys: Per-head activations representing attention keys.
      values: Per-head activations to be combined by computed attention weights.
      mask: Mask that distinguishes positions with real content vs. padding,
          or None for no masking.
      dropout: Probabilistic rate for attention dropout, which overrides
          (sets to zero) some attention strengths derived from query-key
          matching. As a result, on a given forward pass, some value vectors
          don't contribute to the output, analogous to how regular dropout can
          cause some node activations to be ignored. May be None, meaning no
          dropout. Only applied when ``mode == 'train'``.
      mode: One of ``'train'``, ``'eval'``, or ``'predict'``.
      rng: Single-use random number generator (JAX PRNG key).

    Returns:
      Tuple ``(out, dots)``. ``out`` holds the per-head activations resulting
      from the masked, attention-weighted sum of per-head values. ``dots``
      holds the attention weights actually applied (after any dropout). Both
      are cast to float32.

    Raises:
      ValueError: If ``dropout`` is not None and is >= 1.0.
    """
    # Validate before doing any work. `dropout` may be None (the dropout branch
    # below explicitly allows it), and `None >= 1.0` raises TypeError.
    if dropout is not None and dropout >= 1.0:
        raise ValueError('Dropout rates must be lower than 1.')
    d_feature = queries.shape[-1]
    dots = jnp.matmul(queries, jnp.swapaxes(keys, -1,
                                            -2)) / jnp.sqrt(d_feature)
    if mask is not None:
        # Masked-out positions get a huge negative logit, i.e. ~zero weight.
        dots = jnp.where(mask, dots, jnp.full_like(dots, -1e9))
    # Softmax.
    dots = jnp.exp(dots - fastmath.logsumexp(dots, axis=-1, keepdims=True))
    if dropout is not None and dropout > 0.0 and mode == 'train':
        keep = fastmath.random.bernoulli(rng, 1.0 - dropout, dots.shape)
        # Inverted dropout: rescale survivors so the expected value is unchanged.
        dots = jnp.where(keep, dots / (1.0 - dropout), jnp.zeros_like(dots))
    out = jnp.matmul(dots, values)
    out = out.astype(jnp.float32)
    dots = dots.astype(jnp.float32)
    return out, dots
Exemplo n.º 13
0
def DotProductAttention(query, key, value, mask):
    """Scaled dot-product attention.

    Returns ``softmax(query @ key^T / sqrt(depth)) @ value``, where ``depth``
    is the shared last dimension of the inputs. Any position where ``mask`` is
    falsy gets a -1e9 logit before the softmax. Pass ``mask=None`` to attend
    everywhere.
    """
    depth = query.shape[-1]
    assert depth == key.shape[-1] == value.shape[-1]

    scores = jnp.matmul(query, jnp.swapaxes(key, -1, -2)) / jnp.sqrt(depth)
    if mask is not None:
        scores = jnp.where(mask, scores, jnp.full_like(scores, -1e9))

    # Numerically stable softmax: exp(s - logsumexp(s)).
    normalizer = trax.fastmath.logsumexp(scores, axis=-1, keepdims=True)
    weights = jnp.exp(scores - normalizer)
    return jnp.matmul(weights, value)
Exemplo n.º 14
0
def DotProductAttention(query, key, value, mask):
    """Dot product self-attention.

    Computes ``softmax(query @ key^T / sqrt(d)) @ value``. Any position where
    ``mask`` is falsy gets a logit of -1e9 before the softmax.

    Args:
        query (jax.interpreters.xla.DeviceArray): array of query representations with shape (L_q by d)
        key (jax.interpreters.xla.DeviceArray): array of key representations with shape (L_k by d)
        value (jax.interpreters.xla.DeviceArray): array of value representations with shape (L_k by d) where L_v = L_k
        mask (jax.interpreters.xla.DeviceArray): attention-mask, gates attention with shape (L_q by L_k),
            or None to attend to every key.

    Returns:
        jax.interpreters.xla.DeviceArray: Self-attention array for q, k, v arrays. (L_q by d)

    Raises:
        AssertionError: if the last dimensions of query, key and value differ.
    """

    assert query.shape[-1] == key.shape[-1] == value.shape[
        -1], "Embedding dimensions of q, k, v aren't all the same"

    # Depth used to scale the (Q . K) dot product down by sqrt(depth).
    depth = query.shape[-1]

    # Scaled dot product: (L_q, d) @ (d, L_k) / sqrt(d) -> (L_q, L_k).
    dots = jnp.matmul(query, jnp.swapaxes(key, -1, -2)) / jnp.sqrt(depth)

    # Apply the mask: masked-out positions get a huge negative logit.
    if mask is not None:
        dots = jnp.where(mask, dots, jnp.full_like(dots, -1e9))

    # Softmax via logsumexp, which keeps exp() from overflowing on large logits.
    logsumexp = trax.fastmath.logsumexp(dots, axis=-1, keepdims=True)

    # Note: softmax = e^(dots - logsumexp(dots)) = e^dots / sumexp(dots)
    dots = jnp.exp(dots - logsumexp)

    # Weight the values by the attention probabilities: (L_q, L_k) @ (L_k, d).
    attention = jnp.matmul(dots, value)

    return attention