예제 #1
0
 def fn(dist_inputs, actions, q_values, act_log_probs, mask):
     """Turns sampled Q-values into advantages against their mean.

     Only `q_values` contributes to the result. Its first two axes are
     swapped, the mean over the (new) leading axis is used as the baseline,
     and the difference is optionally passed through
     `self._preprocess_advantages` when the enclosing `preprocess` flag is
     truthy.
     """
     del dist_inputs, actions, mask
     # After the swap the leading axis is n_samples
     # (presumably the inputs arrive as (batch_size, n_samples, ...)).
     per_sample_q = jnp.swapaxes(q_values, 0, 1)
     # The swapped log-probs are not used in this variant; the call is kept
     # so that malformed input still fails in the same place.
     act_log_probs = jnp.swapaxes(act_log_probs, 0, 1)
     baseline = jnp.mean(per_sample_q, axis=0)
     # The baseline broadcasts over the n_samples axis.
     centered = per_sample_q - baseline
     if not preprocess:
         return centered
     return self._preprocess_advantages(centered)
예제 #2
0
 def fn(dist_inputs, actions, q_values, act_log_probs, mask):
     """Turns sampled Q-values into advantages against a value baseline.

     The baseline is the plain mean of the Q-values over the sample axis,
     unless `self._sample_all_discrete_actions` is set, in which case each
     Q-value is weighted by `exp(act_log_probs)` and the products are summed.
     The result is optionally passed through `self._preprocess_advantages`
     when the enclosing `preprocess` flag is truthy.
     """
     del dist_inputs, actions, mask
     # After the swaps the leading axis is n_samples.
     per_sample_q = jnp.swapaxes(q_values, 0, 1)
     per_sample_log_probs = jnp.swapaxes(act_log_probs, 0, 1)
     if not self._sample_all_discrete_actions:
         baseline = jnp.mean(per_sample_q, axis=0)
     else:
         # Probability-weighted sum over the sample axis.
         action_probs = jnp.exp(per_sample_log_probs)
         baseline = jnp.sum(per_sample_q * action_probs, axis=0)
     # The baseline broadcasts over the n_samples axis.
     centered = per_sample_q - baseline
     if not preprocess:
         return centered
     return self._preprocess_advantages(centered)
예제 #3
0
 def LossInput(dist_inputs, actions, advantages, old_dist_inputs, mask):  # pylint: disable=invalid-name
   """Builds the loss inputs: action log-probs plus preprocessed advantages.

   Returns a 4-tuple `(log_probs, advantages, log_probs, mask)`; the same
   log-probabilities fill both the first and third slots.
   """
   del old_dist_inputs
   normalized = self._preprocess_advantages(advantages)
   # Replicate the distribution inputs along a new leading n_samples axis.
   tiled_shape = (self._q_value_n_samples,) + dist_inputs.shape
   tiled_inputs = jnp.broadcast_to(dist_inputs, tiled_shape)
   new_log_probs = self._policy_dist.log_prob(tiled_inputs, actions)
   # (batch_size, n_samples, ...) -> (n_samples, batch_size, ...)
   normalized = jnp.swapaxes(normalized, 0, 1)
   swapped_mask = jnp.swapaxes(mask, 0, 1)
   return (new_log_probs, normalized, new_log_probs, swapped_mask)
예제 #4
0
 def gather_fn(x):
   """Selects entries of one tensor using the enclosing index arrays.

   Scalars are passed through. When the leading dimension equals
   `batch_size`, the tensor is indexed directly with
   `[batch_indices, beam_indices]`. Otherwise the leading dimension must be
   a multiple of `batch_size`; it is split into `(batch_size, -1)`, the
   extra factor is moved out of the way for indexing, and the result is
   folded back into a single leading dimension.
   """
   if x.ndim == 0:  # scalars (e.g. cache index) carry nothing to gather
     return x
   if x.shape[0] == batch_size:
     return x[batch_indices, beam_indices]
   assert x.shape[0] % batch_size == 0
   grouped = x.reshape((batch_size, -1) + x.shape[1:])
   # Bring x's original second axis next to the batch axis so that the
   # (batch, beam) index pair addresses axes 0 and 1.
   swapped = np.swapaxes(grouped, 1, 2)
   picked = swapped[batch_indices, beam_indices]
   restored = np.swapaxes(picked, 1, 2)
   return restored.reshape((-1,) + restored.shape[2:])
예제 #5
0
    def _run_value_model(self, observations, dist_inputs):
        """Evaluates the value model and returns `(values, actions, log_probs)`.

        When `self._q_value` is set, actions are sampled from the policy
        distribution (`self._q_value_n_samples` copies of `dist_inputs`) and
        fed to the model together with the observations; otherwise the model
        sees only the observations and `actions` / `log_probs` are `None`.
        A missing `dist_inputs` is replaced by zeros.
        """
        if dist_inputs is None:
            zeros_shape = observations.shape[:2] + (
                self._policy_dist.n_inputs, )
            dist_inputs = jnp.zeros(zeros_shape)

        if not self._q_value:
            actions = None
            log_probs = None
            inputs = (observations, )
        else:
            tiled_shape = (self._q_value_n_samples, ) + dist_inputs.shape
            tiled = jnp.broadcast_to(dist_inputs, tiled_shape)
            # Put batch_size ahead of n_samples, so the input is split
            # between accelerators along the batch_size axis.
            dist_inputs = jnp.swapaxes(tiled, 0, 1)
            actions = self._policy_dist.sample(dist_inputs)
            log_probs = self._policy_dist.log_prob(dist_inputs, actions)
            # Insert a singleton axis after the first observation axis.
            expanded_shape = ([observations.shape[0], 1] +
                              list(observations.shape[1:]))
            inputs = (jnp.reshape(observations, expanded_shape), actions)

        n_devices = math.device_count()
        weights = tl.for_n_devices(self._value_eval_model.weights, n_devices)
        state = tl.for_n_devices(self._value_eval_model.state, n_devices)
        rng = self._value_eval_model.rng
        values, _ = self._value_eval_jit(inputs, weights, state, rng)
        values *= self._value_network_scale
        # Drop the singleton depth dimension.
        values = jnp.squeeze(values, axis=-1)
        return (values, actions, log_probs)
예제 #6
0
def DotProductAttention(query, key, value, mask, dropout, mode, rng):
    """Core dot product self-attention.

    Args:
      query: array of representations
      key: array of representations
      value: array of representations
      mask: attention-mask, gates attention; `None` disables masking
      dropout: float or None: dropout rate; `None` or 0.0 disables dropout
      mode: 'eval' or 'train': whether to use dropout
      rng: JAX PRNGKey: subkey for disposable use

    Returns:
      Self attention for q, k, v arrays.

    Raises:
      ValueError: if `dropout` is 1.0 or greater.
    """
    # Validate before doing any array work. The `None` guard matters: the
    # dropout branch below accepts `None` as "no dropout", but an unguarded
    # `None >= 1.0` comparison raises TypeError on Python 3.
    if dropout is not None and dropout >= 1.0:
        raise ValueError('Dropout rates must be lower than 1.')
    depth = np.shape(query)[-1]
    dots = np.matmul(query, np.swapaxes(key, -1, -2)) / np.sqrt(depth)
    if mask is not None:
        # TODO(kitaev): workaround for https://github.com/google/jax/issues/850
        # We must ensure that both mask and the -1e9 constant have a data dependency
        # on the input. Broadcasted copies of these use a lot of memory, so they
        # should be computed at runtime (rather than being global constants).
        if math.backend_name() == 'jax':
            mask = jax.lax.tie_in(dots, mask)
        # JAX's `full_like` already ties in -1e9 to dots.
        # Positions where the mask is false get a large negative logit.
        dots = np.where(mask, dots, np.full_like(dots, -1e9))
    # Softmax.
    dots = np.exp(dots - math.logsumexp(dots, axis=-1, keepdims=True))
    if dropout is not None and dropout > 0.0 and mode == 'train':
        keep = math.random.bernoulli(rng, 1.0 - dropout, dots.shape)
        # Rescale the kept weights by 1 / keep-probability.
        dots = np.where(keep, dots / (1.0 - dropout), np.zeros_like(dots))
    out = np.matmul(dots, value)
    return out
예제 #7
0
        def LossInput(dist_inputs, actions, q_values, act_log_probs, mask):  # pylint: disable=invalid-name
            """Builds the loss inputs from sampled Q-values.

            Advantages are the Q-values minus their mean over the sample
            axis, passed through `self._preprocess_advantages`. Returns
            `(log_probs, advantages, act_log_probs, mask)`, where `log_probs`
            are computed by the policy distribution for `actions` and the
            last two entries are the axis-swapped inputs.
            """
            # (batch_size, n_samples, ...) -> (n_samples, batch_size, ...)
            q_by_sample = jnp.swapaxes(q_values, 0, 1)
            mask_by_sample = jnp.swapaxes(mask, 0, 1)
            actions_by_sample = jnp.swapaxes(actions, 0, 1)
            old_log_probs = jnp.swapaxes(act_log_probs, 0, 1)

            # TODO(pkozakowski,lukaszkaiser): Try max here, or reweighting?
            # Reweight: values = jnp.sum(q_values * jnp.exp(act_log_probs), axis=0)
            baseline = jnp.mean(q_by_sample, axis=0)
            # The baseline broadcasts over the n_samples axis.
            advantages = self._preprocess_advantages(q_by_sample - baseline)

            # Replicate the distribution inputs per sample, then score the
            # actions under them.
            tiled_shape = (self._q_value_n_samples, ) + dist_inputs.shape
            tiled_inputs = jnp.broadcast_to(dist_inputs, tiled_shape)
            log_probs = self._policy_dist.log_prob(tiled_inputs,
                                                   actions_by_sample)
            return (log_probs, advantages, old_log_probs, mask_by_sample)
예제 #8
0
    def _run_value_model(self, observations, dist_inputs):
        """Evaluates the value model and returns `(values, actions, log_probs)`.

        In Q-value mode the actions are either drawn from the policy
        distribution or, when `self._sample_all_discrete_actions` is set,
        enumerated as `0 .. self._vocab_size - 1`. Outside Q-value mode the
        model sees only the observations and `actions` / `log_probs` are
        `None`. A missing `dist_inputs` is replaced by zeros.
        """
        if dist_inputs is None:
            zeros_shape = observations.shape[:2] + (
                self._policy_dist.n_inputs, )
            dist_inputs = jnp.zeros(zeros_shape)

        if not self._q_value:
            actions = None
            log_probs = None
            inputs = (observations, )
        else:
            enumerate_all = self._sample_all_discrete_actions
            # Number of singleton axes needed so the action vector has the
            # same rank as dist_inputs (measured before broadcasting).
            n_singleton_axes = len(dist_inputs.shape) - 1
            tiled_shape = (self._q_value_n_samples, ) + dist_inputs.shape
            tiled = jnp.broadcast_to(dist_inputs, tiled_shape)
            # Put batch_size ahead of n_samples, so the input is split
            # between accelerators along the batch_size axis.
            dist_inputs = jnp.swapaxes(tiled, 0, 1)
            if enumerate_all:
                # [0, ..., vocab_size-1] shaped as [vocab_size, 1, ..., 1].
                act = jnp.reshape(np.arange(self._vocab_size),
                                  [-1] + [1] * n_singleton_axes)
                # Adding zeros tiles the action ids across the other axes.
                tiled_actions = act + jnp.zeros(tiled.shape[:-1],
                                                dtype=jnp.int32)
                actions = jnp.swapaxes(tiled_actions, 0, 1)
            else:
                actions = self._policy_dist.sample(dist_inputs)
            log_probs = self._policy_dist.log_prob(dist_inputs, actions)
            # Insert a singleton axis after the first observation axis.
            expanded_shape = ([observations.shape[0], 1] +
                              list(observations.shape[1:]))
            inputs = (jnp.reshape(observations, expanded_shape), actions)

        n_devices = math.device_count()
        weights = tl.for_n_devices(self._value_eval_model.weights, n_devices)
        state = tl.for_n_devices(self._value_eval_model.state, n_devices)
        rng = self._value_eval_model.rng
        values, _ = self._value_eval_jit(inputs, weights, state, rng)
        values *= self._value_network_scale
        # Drop the singleton depth dimension.
        values = jnp.squeeze(values, axis=-1)
        return (values, actions, log_probs)
예제 #9
0
def unflatten_beam_dim(x, batch_size, beam_size):
  """Splits the flat leading dimension of `x` into a beam dimension.

  Scalars are returned unchanged. When the leading dimension equals
  `batch_size * beam_size` the result has shape
  `(batch_size, beam_size) + x.shape[1:]`. Otherwise the leading dimension
  must be a multiple of `batch_size * beam_size`; the extra factor is folded
  into the first output dimension and `beam_size` becomes the second.
  """
  if x.ndim == 0:  # scalars (e.g. cache index) have nothing to split
    return x
  trailing = x.shape[1:]
  flat_size = batch_size * beam_size
  if x.shape[0] != flat_size:
    assert x.shape[0] % (batch_size * beam_size) == 0
    grouped = x.reshape((batch_size, beam_size, -1) + trailing)
    # Move the beam axis behind the extra factor before merging it away.
    swapped = np.swapaxes(grouped, 1, 2)
    return swapped.reshape((-1, beam_size) + swapped.shape[3:])
  return x.reshape((batch_size, beam_size) + trailing)
예제 #10
0
def flatten_beam_dim(x, batch_size=None):
  """Merges the first two dimensions of `x` into one.

  Scalars are returned unchanged. When `batch_size` is given and differs
  from the leading dimension, the leading dimension must be a multiple of
  `batch_size`; it is split into `(batch_size, -1)` and the second original
  axis is swapped in front of the extra factor before everything is merged.
  """
  if x.ndim == 0:  # scalars (e.g. cache index) have nothing to merge
    return x
  if batch_size is None or x.shape[0] == batch_size:
    merged = x.shape[0] * x.shape[1]
    return x.reshape((merged,) + x.shape[2:])
  assert x.shape[0] % batch_size == 0
  grouped = x.reshape((batch_size, -1, x.shape[1]) + x.shape[2:])
  swapped = np.swapaxes(grouped, 1, 2)
  dim0, dim1, dim2 = swapped.shape[:3]
  return swapped.reshape((dim0 * dim1 * dim2,) + swapped.shape[3:])
예제 #11
0
  def policy_batches_stream(self):
    """Yields policy-model inputs built from the RLTask `self._task`.

    Each yielded item is
    `(observations, actions, advantages, mask, mask)`, where advantages are
    the Q-values minus their mean over the sample axis and the trajectory
    mask is broadcast to the advantages' shape.

    Raises:
      ValueError: if the averaged values are not 2-D or their first
        dimension differs from `self._policy_batch_size`.
    """
    # For now TD-0 estimation of the value. TODO(pkozakowski): Support others?
    trajectory_batches = self._task.trajectory_batch_stream(
        self._policy_batch_size,
        epochs=self._replay_epochs,
        max_slice_length=self._max_slice_length,
        include_final_state=False,
    )
    for np_trajectory in trajectory_batches:
      q_values, actions = self._run_value_model(
          np_trajectory.observations, np_trajectory.dist_inputs
      )
      # TODO(pkozakowski): Try max here.
      values = jnp.mean(q_values, axis=0)

      values_shape = values.shape
      if len(values_shape) != 2:
        raise ValueError('Values are expected to have shape ' +
                         '[batch_size, length], got: %s' % str(values_shape))
      if values_shape[0] != self._policy_batch_size:
        raise ValueError('Values first dimension should = policy batch size, ' +
                         '%d != %d' %(values_shape[0], self._policy_batch_size))

      # q_values is (n_samples, batch_size, length) and values is
      # (batch_size, length), so the subtraction broadcasts over n_samples.
      advantages = q_values - values
      mask = jnp.broadcast_to(np_trajectory.mask, advantages.shape)

      expected_shape = (self._q_value_n_samples,) + values_shape
      shapes.assert_shape_equals(advantages, expected_shape)
      shapes.assert_same_shape(mask, advantages)

      # Put batch_size ahead of n_samples, so the input is split between
      # accelerators along the batch_size axis.
      advantages = jnp.swapaxes(advantages, 0, 1)
      mask = jnp.swapaxes(mask, 0, 1)

      yield (np_trajectory.observations, actions, advantages, mask, mask)
0
def DotProductAttention(queries, keys, values, mask, dropout, mode, rng):
    """Computes new activations via masked attention-weighted sum of values.

    This function is the core of the attention mechanism. It:
      - computes per-head attention weights from per-head `(queries, keys)`,
      - applies `mask` to screen out positions that come from padding tokens,
      - optionally applies dropout to attention weights, and
      - uses attention weights to combine per-head `values` vectors.

    Args:
      queries: Per-head activations representing attention queries.
      keys: Per-head activations representing attention keys.
      values: Per-head activations to be combined by computed attention
          weights.
      mask: Mask that distinguishes positions with real content vs. padding;
          `None` disables masking.
      dropout: Probabilistic rate for dropout applied to attention
          activations (based on query-key pairs) before dotting them with
          values. `None` or 0.0 disables dropout.
      mode: Either 'train' or 'eval'. Dropout applies only in 'train' mode.
      rng: Single-use random number generator (JAX PRNG key).

    Returns:
      Per-head activations resulting from masked per-head attention-weighted
      sum of per-head values.

    Raises:
      ValueError: if `dropout` is 1.0 or greater.
    """
    # Validate before doing any array work. The `None` guard matters: the
    # dropout branch below accepts `None` as "no dropout", but an unguarded
    # `None >= 1.0` comparison raises TypeError on Python 3.
    if dropout is not None and dropout >= 1.0:
        raise ValueError('Dropout rates must be lower than 1.')
    d_feature = queries.shape[-1]
    dots = jnp.matmul(queries, jnp.swapaxes(keys, -1,
                                            -2)) / jnp.sqrt(d_feature)
    if mask is not None:
        # TODO(kitaev): workaround for https://github.com/google/jax/issues/850
        # We must ensure that both mask and the -1e9 constant have a data dependency
        # on the input. Broadcasted copies of these use a lot of memory, so they
        # should be computed at runtime (rather than being global constants).
        if math.backend_name() == 'jax':
            mask = jax.lax.tie_in(dots, mask)
        # JAX's `full_like` already ties in -1e9 to dots.
        # Positions where the mask is false get a large negative logit.
        dots = jnp.where(mask, dots, jnp.full_like(dots, -1e9))
    # Softmax.
    dots = jnp.exp(dots - math.logsumexp(dots, axis=-1, keepdims=True))
    if dropout is not None and dropout > 0.0 and mode == 'train':
        keep = math.random.bernoulli(rng, 1.0 - dropout, dots.shape)
        # Rescale the kept weights by 1 / keep-probability.
        dots = jnp.where(keep, dots / (1.0 - dropout), jnp.zeros_like(dots))
    out = jnp.matmul(dots, values)
    return out
예제 #13
0
def attend(
    q,
    k=None,
    v=None,
    q_chunk_len=None,
    kv_chunk_len=None,
    n_chunks_before=0,
    n_chunks_after=0,
    mask_fn=None,
    q_info=None,
    kv_info=None,
    dropout=0.0,
    rng=None,
):
    """Dot-product attention, with optional chunking and/or masking.

    Args:
      q: Query vectors, shape [q_len, d_qk]
      k: Key vectors, shape [kv_len, d_qk]; or None, in which case the
        queries are reused as keys (shared-QK attention) and the keys are
        passed through `length_normalized`.
      v: Value vectors, shape [kv_len, d_v]. Required despite the `None`
        default.
      q_chunk_len: Chunk length for query vectors; `None` disables chunking.
      kv_chunk_len: Chunk length for key/value vectors; `None` disables
        chunking. With shared QK it must be `None` or equal to `q_chunk_len`.
        Query and key/value chunking must be enabled together.
      n_chunks_before: Number of adjacent previous chunks to attend to. Must
        be 0 when chunking is disabled.
      n_chunks_after: Number of adjacent subsequent chunks to attend to. Must
        be 0 when chunking is disabled.
      mask_fn: Optional callable invoked as
        `mask_fn(dots, q_info[..., :, None], kv_info[..., None, :])`; its
        return value replaces the attention logits before the softmax.
      q_info: Query-associated metadata for masking. Defaults to
        `arange(q.shape[-2])`.
      kv_info: Key-associated metadata for masking. Defaults to
        `arange(v.shape[-2])`, or to `q_info` with shared QK.
      dropout: Dropout rate; dropout is applied only when this is > 0.0.
      rng: RNG for dropout. Required when `dropout > 0.0`.

    Returns:
      A tuple (output, dots_logsumexp). The output has shape [q_len, d_v], and
      dots_logsumexp has shape [q_len]. The logsumexp of the attention
      probabilities is useful for combining multiple rounds of attention (as in
      LSH attention).
    """
    assert v is not None
    share_qk = (k is None)

    # Default metadata is simply the position index along the length axis.
    if q_info is None:
        q_info = np.arange(q.shape[-2])

    if kv_info is None and not share_qk:
        kv_info = np.arange(v.shape[-2])

    # Split q/k/v into chunks along the time axis, if desired.
    if q_chunk_len is not None:
        q = np.reshape(q, (-1, q_chunk_len, q.shape[-1]))
        q_info = np.reshape(q_info, (-1, q_chunk_len))

    if share_qk:
        # Keys mirror the (possibly already chunked) queries and their info.
        assert kv_chunk_len is None or kv_chunk_len == q_chunk_len
        k = q
        kv_chunk_len = q_chunk_len
        kv_info = q_info
    elif kv_chunk_len is not None:
        k = np.reshape(k, (-1, kv_chunk_len, k.shape[-1]))
        kv_info = np.reshape(kv_info, (-1, kv_chunk_len))

    if kv_chunk_len is not None:
        v = np.reshape(v, (-1, kv_chunk_len, v.shape[-1]))

    if share_qk:
        k = length_normalized(k)
    # Scale the keys (rather than the logits) by 1/sqrt(d_qk).
    k = k / np.sqrt(k.shape[-1])

    # Optionally include adjacent chunks.
    if q_chunk_len is not None or kv_chunk_len is not None:
        assert q_chunk_len is not None and kv_chunk_len is not None
    else:
        assert n_chunks_before == 0 and n_chunks_after == 0

    k = look_adjacent(k, n_chunks_before, n_chunks_after)
    v = look_adjacent(v, n_chunks_before, n_chunks_after)
    kv_info = look_adjacent(kv_info, n_chunks_before, n_chunks_after)

    # Dot-product attention.
    dots = np.matmul(q, np.swapaxes(k, -1, -2))

    # Masking
    if mask_fn is not None:
        dots = mask_fn(dots, q_info[..., :, None], kv_info[..., None, :])

    # Softmax.
    dots_logsumexp = logsumexp(dots, axis=-1, keepdims=True)
    dots = np.exp(dots - dots_logsumexp)

    if dropout > 0.0:
        assert rng is not None
        # Dropout is broadcast across the bin dimension
        dropout_shape = (dots.shape[-2], dots.shape[-1])
        # TODO(kitaev): verify that tie-in is safe to remove (in light of jax fix)
        keep_prob = jax.lax.tie_in(dots, 1.0 - dropout)
        keep = jax.random.bernoulli(rng, keep_prob, dropout_shape)
        # Kept entries are scaled by 1 / keep_prob; dropped entries become 0.
        multiplier = keep.astype(dots.dtype) / jax.lax.tie_in(keep, keep_prob)
        dots = dots * multiplier

    # The softmax normalizer (dots_logsumexp) is used by multi-round LSH attn.
    out = np.matmul(dots, v)
    # Undo any chunking: collapse all leading axes into one.
    out = np.reshape(out, (-1, out.shape[-1]))
    dots_logsumexp = np.reshape(dots_logsumexp, (-1, ))
    return out, dots_logsumexp