Example #1
 def f(log_probs, advantages, old_log_probs, mask):
     if reweight:  # Use new policy weights for sampled actions instead.
         mask *= jnp.exp(math.stop_gradient(log_probs) - old_log_probs)
     if sampled_all_discrete:  # Actions were sampled uniformly; weight them.
         mask *= jnp.exp(old_log_probs)
     weights = jnp.minimum(awr_weights(advantages, beta), w_max)
     return -jnp.sum(log_probs * weights * mask) / jnp.sum(mask)
Example #2
        def PPOJointLoss(x, **unused_kwargs):
            """Definition of the Proximal Policy Optimization loss."""
            dist_inputs, values, returns, actions, old_log_probs, mask = x
            del mask  # TODO(lukaszkaiser): make PPO work with Transformer
            new_log_probs = self._policy_dist.log_prob(dist_inputs, actions)

            advantages = returns - values
            l2_value_loss = jnp.sum(advantages**2) * self._value_loss_coeff

            # Old log probs have an undesirable extra dimension which we remove here
            old_log_probs = jnp.array(old_log_probs.squeeze(axis=-1),
                                      dtype=jnp.float32)
            new_log_probs = jnp.array(new_log_probs.squeeze(axis=-1))

            # The ratio between new_probs and old_probs expressed
            # using log_probs and exponentiation
            probs_ratio = jnp.exp(new_log_probs - old_log_probs)
            unclipped_objective = probs_ratio * advantages
            clipped_objective = jnp.clip(probs_ratio, 1 - self._epsilon,
                                         1 + self._epsilon) * advantages
            ppo_objective = jnp.minimum(unclipped_objective, clipped_objective)

            entropy_loss = (self._policy_dist.entropy(new_log_probs) *
                            self._entropy_coeff)

            return -ppo_objective.mean() + l2_value_loss - entropy_loss
Example #3
def DotProductAttention(query, key, value, mask, dropout, mode, rng):
    """Core dot product self-attention.

  Args:
    query: array of representations
    key: array of representations
    value: array of representations
    mask: attention-mask, gates attention
    dropout: float: dropout rate
    mode: 'eval' or 'train': whether to use dropout
    rng: JAX PRNGKey: subkey for disposable use

  Returns:
    Self attention for q, k, v arrays.
  """
    depth = np.shape(query)[-1]
    dots = np.matmul(query, np.swapaxes(key, -1, -2)) / np.sqrt(depth)
    if mask is not None:
        # TODO(kitaev): workaround for https://github.com/google/jax/issues/850
        # We must ensure that both mask and the -1e9 constant have a data dependency
        # on the input. Broadcasted copies of these use a lot of memory, so they
        # should be computed at runtime (rather than being global constants).
        if math.backend_name() == 'jax':
            mask = jax.lax.tie_in(dots, mask)
        # JAX's `full_like` already ties in -1e9 to dots.
        dots = np.where(mask, dots, np.full_like(dots, -1e9))
    # Softmax.
    dots = np.exp(dots - math.logsumexp(dots, axis=-1, keepdims=True))
    if dropout >= 1.0:
        raise ValueError('Dropout rates must be lower than 1.')
    if dropout is not None and dropout > 0.0 and mode == 'train':
        keep = math.random.bernoulli(rng, 1.0 - dropout, dots.shape)
        dots = np.where(keep, dots / (1.0 - dropout), np.zeros_like(dots))
    out = np.matmul(dots, value)
    return out
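The core computation above can be tried in isolation. Below is a minimal,
self-contained sketch (an illustration, not part of the source) using plain
jax.numpy, with no masking or dropout:

import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

rng = jax.random.PRNGKey(0)
q_key, k_key, v_key = jax.random.split(rng, 3)
q = jax.random.normal(q_key, (4, 8))   # [length, depth]
k = jax.random.normal(k_key, (4, 8))
v = jax.random.normal(v_key, (4, 8))

# Scaled dot-product attention, as in DotProductAttention above.
depth = q.shape[-1]
dots = jnp.matmul(q, jnp.swapaxes(k, -1, -2)) / jnp.sqrt(depth)
weights = jnp.exp(dots - logsumexp(dots, axis=-1, keepdims=True))  # softmax
out = jnp.matmul(weights, v)
assert out.shape == (4, 8)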
Example #4
def Softmax5Branches(x_list, **unused_kwargs):
    """Softmax qs.

  The input xs is a list of weights and embedded queries of the form
  w_1 ... w_n q_1 ... q_n. The q_1 ... q_n will be kept, result appended.

  Args:
    x_list: the input weights and embeddings.

  Returns:
    the weighted average of q_1 ... q_n according to softmax(w).
  """
    n_branches = 5
    softmax_activations = x_list[:n_branches]
    max_sa = softmax_activations[0]
    for x in softmax_activations:
        max_sa = np.maximum(max_sa, x)
    softmax_activations = [x - max_sa for x in softmax_activations]
    softmax_activations = [np.exp(x) for x in softmax_activations]
    sum_sa = sum(softmax_activations)
    softmax_activations = [x / sum_sa for x in softmax_activations]
    res = sum([
        x_list[i + n_branches] * softmax_activations[i]
        for i in range(n_branches)
    ])
    return res
Example #5
 def _aggregate_values(self, values, aggregate_max, act_log_probs):
     if self._q_value:
         if aggregate_max:
             values = jnp.max(values, axis=1)
         elif self._sample_all_discrete_actions:
             values = jnp.sum(values * jnp.exp(act_log_probs), axis=1)
         else:
             values = jnp.mean(values, axis=1)
     return np.array(values)  # Move the values to CPU.
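For intuition, here is a standalone sketch (with made-up Q-values, not part of
the source) contrasting the two non-max aggregation modes above: an exact
expectation when all discrete actions were enumerated, versus a plain mean
over sampled actions:

import jax.numpy as jnp

q_values = jnp.array([[1.0, 2.0, 3.0]])                 # [batch, n_actions]
act_log_probs = jnp.log(jnp.array([[0.5, 0.25, 0.25]]))

# Expectation under the policy: Q-values weighted by action probabilities.
expectation = jnp.sum(q_values * jnp.exp(act_log_probs), axis=1)  # [1.75]
# Unweighted mean over samples.
mean = jnp.mean(q_values, axis=1)                                 # [2.0]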
Example #6
    def _calc_adv_weights(self, adv, valid_mask):
        weights = jnp.exp(adv / self._temperature)

        valid_weights = weights[valid_mask]
        weights_mean = jnp.mean(valid_weights)
        weights_min = jnp.min(valid_weights)
        weights_max = jnp.max(valid_weights)

        weights = jnp.minimum(weights, self._weight_clip)
        return weights, weights_mean, weights_min, weights_max
Example #7
def ProbsRatio(dist_inputs, actions, old_log_probs, log_prob_fun):
    """Probability Ratio from the PPO algorithm."""
    # Old log probs have an undesirable extra dimension which we remove here
    old_log_probs = jnp.array(old_log_probs.squeeze(axis=-1),
                              dtype=jnp.float32)
    new_log_probs = NewLogProbs(dist_inputs, actions, log_prob_fun)
    # The ratio between new_probs and old_probs expressed
    # using log_probs and exponentiation
    probs_ratio = jnp.exp(new_log_probs - old_log_probs)
    return probs_ratio
Example #8
 def AWRLoss(x, **unused_kwargs):  # pylint: disable=invalid-name
     logps, values, returns, actions = x
     advantage = returns - values
     l2_value_loss = jnp.sum(
         (returns - values)**2) * self._value_loss_coeff
     awr_weights = jnp.minimum(jnp.exp(advantage / self._beta),
                               self._w_max)
     log_loss = -1.0 * self._policy_dist.log_prob(logps, actions)
     policy_loss = jnp.sum(
         log_loss * awr_weights) / jnp.sum(awr_weights)
     return policy_loss + l2_value_loss
Example #9
def Softmax(axis=-1):
  """Returns a layer that applies softmax along one tensor axis.

  `Softmax` acts on a group of values and normalizes them to look like a set
  of probability values. (Probability values must be non-negative, and as a
  set must sum to 1.)

  Args:
    axis: Axis along which values are grouped for computing softmax.
  """
  return Fn('Softmax',
            lambda x: jnp.exp(x - math.logsumexp(x, axis, keepdims=True)))
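A quick numeric check (an illustration, assuming only jax.numpy) that the
logsumexp formulation used above matches the textbook softmax definition:

import jax.numpy as jnp
from jax.scipy.special import logsumexp

x = jnp.array([1.0, 2.0, 3.0])
# Numerically stable softmax, as in the layer above.
stable = jnp.exp(x - logsumexp(x, axis=-1, keepdims=True))
# Naive definition for comparison.
naive = jnp.exp(x) / jnp.sum(jnp.exp(x))
assert jnp.allclose(stable, naive)  # ~[0.090, 0.245, 0.665]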
Example #10
 def fn(dist_inputs, actions, q_values, act_log_probs, mask):
     del dist_inputs, actions, mask
     q_values = jnp.swapaxes(q_values, 0, 1)
     act_log_probs = jnp.swapaxes(act_log_probs, 0, 1)
     if self._sample_all_discrete_actions:
         values = jnp.sum(q_values * jnp.exp(act_log_probs), axis=0)
     else:
         values = jnp.mean(q_values, axis=0)
     advantages = q_values - values  # Broadcasting values over n_samples
     if preprocess:
         advantages = self._preprocess_advantages(advantages)
     return advantages
Example #11
    def f(new_log_probs, advantages, old_log_probs, mask):
        # Old log probs have an undesirable extra dimension which we remove here
        old_log_probs = old_log_probs.squeeze(axis=-1)

        # The ratio between new_probs and old_probs expressed
        # using log_probs and exponentiation
        probs_ratio = jnp.exp(new_log_probs - old_log_probs)
        unclipped_objective = probs_ratio * advantages
        clipped_objective = jnp.clip(probs_ratio, 1 - epsilon,
                                     1 + epsilon) * advantages
        ppo_objective = jnp.minimum(unclipped_objective, clipped_objective)
        return -np.sum(ppo_objective * mask) / np.sum(mask)
Example #12
def ProbsRatio(dist_inputs, actions, old_log_probs, log_prob_fun):
    """Probability Ratio from the PPO algorithm."""
    # dist_inputs of the shape float32[128,1,18]
    # actions of the shape int32[128,1]
    # and old_log_probs of the shape float32[128,1]
    new_log_probs = NewLogProbs(dist_inputs, actions, log_prob_fun)
    assert new_log_probs.shape == old_log_probs.shape, (
        f'new_log_probs.shape was {new_log_probs.shape} and '
        f'old_log_probs.shape was {old_log_probs.shape}')
    # The ratio between new_probs and old_probs expressed
    # using log_probs and exponentiation
    probs_ratio = jnp.exp(new_log_probs - old_log_probs)
    return probs_ratio
Example #13
def PPOLoss(x, epsilon, **unused_kwargs):
    """Definition of the Proximal Policy Optimization loss."""
    (new_log_probs, advantages, old_log_probs, mask) = x
    # Old log probs have an undesirable extra dimension which we remove here
    old_log_probs = old_log_probs.squeeze(axis=-1)

    # The ratio between new_probs and old_probs expressed
    # using log_probs and exponentiation
    probs_ratio = jnp.exp(new_log_probs - old_log_probs)
    unclipped_objective = probs_ratio * advantages
    clipped_objective = jnp.clip(probs_ratio, 1 - epsilon,
                                 1 + epsilon) * advantages
    ppo_objective = jnp.minimum(unclipped_objective, clipped_objective)
    return -np.sum(ppo_objective * mask) / np.sum(mask)
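A short numeric sketch (made-up values, not from the source) of how the
clipped objective above caps the update once the probability ratio leaves
[1 - epsilon, 1 + epsilon]:

import jax.numpy as jnp

new_log_probs = jnp.array([-0.5, -2.0])  # one action more likely, one less
old_log_probs = jnp.array([-1.0, -1.0])
advantages = jnp.array([1.0, 1.0])
epsilon = 0.2

probs_ratio = jnp.exp(new_log_probs - old_log_probs)        # ~[1.65, 0.37]
unclipped = probs_ratio * advantages
clipped = jnp.clip(probs_ratio, 1 - epsilon, 1 + epsilon) * advantages
# With a positive advantage, the objective is capped at 1 + epsilon = 1.2.
ppo_objective = jnp.minimum(unclipped, clipped)             # ~[1.2, 0.37]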
Example #14
        def ProbsRatioMean(x, **unused_kwargs):
            """Probability Ratio Mean from the PPO algorithm."""
            dist_inputs, _, _, actions, old_log_probs = x
            new_log_probs = self._policy_dist.log_prob(dist_inputs, actions)

            # Old log probs have an undesirable extra dimension which we remove here
            old_log_probs = jnp.array(old_log_probs.squeeze(axis=-1),
                                      dtype=jnp.float32)
            new_log_probs = jnp.array(new_log_probs.squeeze(axis=-1))

            # The ratio between new_probs and old_probs expressed
            # using log_probs and exponentiation
            probs_ratio = jnp.exp(new_log_probs - old_log_probs)
            return jnp.mean(probs_ratio)
Example #15
    def f(new_log_probs, advantages, old_log_probs, mask):
      # new_log_probs of the shape float32[128,1]
      # advantages of the shape int32[128,1]
      # old_log_probs of the shape int32[128,1]
      # mask of the shape int32[128,1]
      if new_log_probs.shape != advantages.shape:
        raise ValueError('New log-probs and advantages shapes '
                         'should be the same, %s != %s' % (new_log_probs.shape,
                                                           advantages.shape))
      if new_log_probs.shape != old_log_probs.shape:
        raise ValueError('New log-probs and old log-probs shapes '
                         'should be the same, %s != %s' % (new_log_probs.shape,
                                                           old_log_probs.shape))
      if new_log_probs.shape != mask.shape:
        raise ValueError('New log-probs and mask shapes should be the same'
                         ', %s != %s' % (new_log_probs.shape, mask.shape))

      # The ratio between new_probs and old_probs expressed
      # using log_probs and exponentiation
      probs_ratio = jnp.exp(new_log_probs - old_log_probs)
      if advantages.shape != probs_ratio.shape:
        raise ValueError('Advantages and probs-ratio shapes '
                         'should be the same, %s != %s' % (advantages.shape,
                                                           probs_ratio.shape))
      unclipped_objective = probs_ratio * advantages
      clipped_objective = jnp.clip(probs_ratio,
                                   1 - self._epsilon,
                                   1 + self._epsilon) * advantages

      if unclipped_objective.shape != clipped_objective.shape:
        raise ValueError('unclipped_objective and clipped_objective shapes '
                         'should be the same, %s != %s' % (
                             unclipped_objective.shape,
                             clipped_objective.shape))

      ppo_objective = jnp.minimum(unclipped_objective, clipped_objective)

      if ppo_objective.shape != mask.shape:
        raise ValueError('ppo_objective and mask shapes '
                         'should be the same, %s != %s' % (
                             ppo_objective.shape,
                             mask.shape))

      ppo_loss = -jnp.sum(ppo_objective * mask) / jnp.sum(mask)
      entropy_vec = self._policy_dist.entropy(
          new_log_probs) * self._entropy_coeff
      entropy_loss = jnp.mean(entropy_vec)
      combined_loss = ppo_loss - entropy_loss

      return combined_loss
Example #16
def DotProductAttention(queries, keys, values, mask, dropout, mode, rng):
    """Computes new activations via masked attention-weighted sum of values.

  This function is the core of the attention mechanism. It:
    - computes per-head attention weights from per-head `(queries, keys)`,
    - applies `mask` to screen out positions that come from padding tokens,
    - optionally applies dropout to attention weights, and
    - uses attention weights to combine per-head `values` vectors.

  Args:
    queries: Per-head activations representing attention queries.
    keys: Per-head activations representing attention keys.
    values: Per-head activations to be combined by computed attention weights.
    mask: Mask that distinguishes positions with real content vs. padding.
    dropout: Probabilistic rate for dropout applied to attention activations
        (based on query-key pairs) before dotting them with values.
    mode: Either 'train' or 'eval'. Dropout applies only in 'train' mode.
    rng: Single-use random number generator (JAX PRNG key).

  Returns:
    Per-head activations resulting from masked per-head attention-weighted
    sum of per-head values.
  """
    d_feature = queries.shape[-1]
    dots = jnp.matmul(queries, jnp.swapaxes(keys, -1,
                                            -2)) / jnp.sqrt(d_feature)
    if mask is not None:
        # TODO(kitaev): workaround for https://github.com/google/jax/issues/850
        # We must ensure that both mask and the -1e9 constant have a data dependency
        # on the input. Broadcasted copies of these use a lot of memory, so they
        # should be computed at runtime (rather than being global constants).
        if math.backend_name() == 'jax':
            mask = jax.lax.tie_in(dots, mask)
        # JAX's `full_like` already ties in -1e9 to dots.
        dots = jnp.where(mask, dots, jnp.full_like(dots, -1e9))
    # Softmax.
    dots = jnp.exp(dots - math.logsumexp(dots, axis=-1, keepdims=True))
    if dropout >= 1.0:
        raise ValueError('Dropout rates must be lower than 1.')
    if dropout is not None and dropout > 0.0 and mode == 'train':
        keep = math.random.bernoulli(rng, 1.0 - dropout, dots.shape)
        dots = jnp.where(keep, dots / (1.0 - dropout), jnp.zeros_like(dots))
    out = jnp.matmul(dots, values)
    return out
Example #17
        def LossInput(dist_inputs, actions, q_values, act_log_probs, mask):  # pylint: disable=invalid-name
            """Calculates action log probabilities and normalizes advantages."""
            # (batch_size, n_samples, ...) -> (n_samples, batch_size, ...)
            q_values = jnp.swapaxes(q_values, 0, 1)
            mask = jnp.swapaxes(mask, 0, 1)
            actions = jnp.swapaxes(actions, 0, 1)
            act_log_probs = jnp.swapaxes(act_log_probs, 0, 1)

            # TODO(pkozakowski,lukaszkaiser): Try max here, or reweighting?
            if self._sample_all_discrete_actions:
                values = jnp.sum(q_values * jnp.exp(act_log_probs), axis=0)
            else:
                values = jnp.mean(q_values, axis=0)
            advantages = q_values - values  # Broadcasting values over n_samples
            advantages = self._preprocess_advantages(advantages)

            # Broadcast inputs and calculate log-probs
            dist_inputs = jnp.broadcast_to(
                dist_inputs, (self._q_value_n_samples, ) + dist_inputs.shape)
            log_probs = self._policy_dist.log_prob(dist_inputs, actions)
            return (log_probs, advantages, act_log_probs, mask)
Example #18
def SRU(n_units, activation=None, rescale=False, highway_bias=0):
    """SRU (Simple Recurrent Unit) layer as in https://arxiv.org/abs/1709.02755.

  As defined in the paper:
  (1) y_t = W x_t (+ B optionally, which we do)
  (2) f_t = sigmoid(Wf x_t + bf)
  (3) r_t = sigmoid(Wr x_t + br)
  (4) c_t = f_t * c_{t-1} + (1 - f_t) * y_t
  (5) h_t = r_t * activation(c_t) + (1 - r_t) * x_t * alpha

  We assume the input is of shape [batch, length, depth] and recurrence
  happens on the length dimension. This returns a single layer; the paper
  recommends stacking at least two, except inside a Transformer.

  Args:
    n_units: output depth of the SRU layer.
    activation: Optional activation function.
    rescale: To offset the problem of the gradient vanishing in the h_t as a
      result of light recurrence and highway computation for deeper layers, a
      scaling correction alpha is applied as follows:
      (1 + exp(highway_bias) * 2)**0.5; ref: https://arxiv.org/abs/1709.02755,
      page 4, section 3.2 Initialization.
    highway_bias: initial bias of highway gates.
  Returns:
    The SRU layer.
  """
    # pylint: disable=no-value-for-parameter
    return cb.Serial(  # x
        cb.Branch(core.Dense(3 * n_units), []),  # r_f_y, x
        cb.Split(n_items=3),  # r, f, y, x
        cb.Parallel(core.Sigmoid(), core.Sigmoid()),  # r, f, y, x
        base.Fn(lambda r, f, y: (y * (1.0 - f), f, r)),  # y * (1 - f), f, r, x
        cb.Parallel([], [], cb.Branch(MakeZeroState(), [])),
        cb.Scan(InnerSRUCell(), axis=1),
        cb.Select([0], n_in=2),  # act(c), r, x
        activation or [],
        base.Fn(lambda c, r, x: c * r + x * (1 - r) *
                ((1 + np.exp(highway_bias) * 2)**0.5 if rescale else 1)))
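To make the recurrence concrete, here is a per-timestep sketch (plain NumPy,
single example, no batching; an illustration rather than the trax
implementation, with tanh standing in for the optional activation) of
equations (1)-(5) from the docstring:

import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def sru_step(x_t, c_prev, W, b, Wf, bf, Wr, br, alpha=1.0):
    y_t = W @ x_t + b                                      # (1)
    f_t = sigmoid(Wf @ x_t + bf)                           # (2)
    r_t = sigmoid(Wr @ x_t + br)                           # (3)
    c_t = f_t * c_prev + (1.0 - f_t) * y_t                 # (4)
    h_t = r_t * np.tanh(c_t) + (1.0 - r_t) * x_t * alpha   # (5)
    return h_t, c_t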
Example #19
def Softmax(axis=-1):
    """Layer that applies softmax: exponentiate and normalize along given axis."""
    return Fn('Softmax',
              lambda x: jnp.exp(x - math.logsumexp(x, axis, keepdims=True)))
Example #20
def Exp():
    return Fn('Exp', lambda x: jnp.exp(x))  # pylint: disable=unnecessary-lambda
Example #21
    def forward_unbatched(self, x, *, weights, state, update_state):
        w_q, w_v, w_o = weights

        q = np.matmul(x, w_q)
        v = np.matmul(x, w_v)

        if update_state:
            _, old_rng = state
            rng = jax.random.fold_in(old_rng, 0)
            hash_rng = jax.random.fold_in(rng, 1)
            buckets = self.hash_vectors(q, hash_rng)
            state = (buckets, rng)
        else:
            buckets, rng = state

        rng = jax.random.fold_in(rng, 2)

        seqlen = x.shape[0]
        assert int(buckets.shape[0]) == self.n_hashes * seqlen

        ticker = jax.lax.tie_in(x, np.arange(self.n_hashes * seqlen))
        buckets_and_t = seqlen * buckets + (ticker % seqlen)
        buckets_and_t = jax.lax.stop_gradient(buckets_and_t)

        # Hash-based sort ("s" at the start of variable names means "sorted")
        sbuckets_and_t, sticker = jax.lax.sort_key_val(buckets_and_t,
                                                       ticker,
                                                       dimension=-1)
        _, undo_sort = jax.lax.sort_key_val(sticker, ticker, dimension=-1)
        sbuckets_and_t = jax.lax.stop_gradient(sbuckets_and_t)
        sticker = jax.lax.stop_gradient(sticker)
        undo_sort = jax.lax.stop_gradient(undo_sort)

        st = (sticker % seqlen)
        sq = np.take(q, st, axis=0)
        sv = np.take(v, st, axis=0)

        mask_fn = functools.partial(mask_self_attention,
                                    causal=self.causal,
                                    exclude_self=True)
        q_info = st
        so, slogits = attend(
            sq,
            k=None,
            v=sv,
            q_chunk_len=self.chunk_len,
            n_chunks_before=self.n_chunks_before,
            n_chunks_after=self.n_chunks_after,
            mask_fn=mask_fn,
            q_info=q_info,
            dropout=self.attention_dropout,
            rng=rng,
        )

        def unsort_for_output_impl(so, slogits):
            o = np.take(so, undo_sort, axis=0)
            # Sorting is considerably faster than gather, but first we need to get the
            # XLA compiler to abandon the idea of fusing this sort with the input sort
            # (which introduces a computation cycle and leads to a crash).
            # TODO(kitaev): remove "sticker_" variable if XLA is fixed.
            sticker_ = sticker + jax.lax.convert_element_type(
                slogits[0] > 0, sticker.dtype)
            _, logits = jax.lax.sort_key_val(sticker_, slogits, dimension=-1)
            return o, logits

        def unsort_for_output_vjp(so, slogits):
            """Custom gradient for unsort_for_output."""
            so = jax.lax.stop_gradient(so)
            slogits = jax.lax.stop_gradient(slogits)
            o, logits = unsort_for_output_impl(so, slogits)

            def vjpfun(o_logits_grads):
                so_grad = np.take(o_logits_grads[0], sticker, axis=0)
                # TODO(kitaev): this exists to match the forward pass, but I'm not sure
                # if it's actually required.
                buckets_and_t_ = buckets_and_t + jax.lax.convert_element_type(
                    o_logits_grads[1][0] > 0, buckets_and_t.dtype)
                _, slogits_grad = jax.lax.sort_key_val(buckets_and_t_,
                                                       o_logits_grads[1],
                                                       dimension=-1)
                return (so_grad, slogits_grad)

            return (o, logits), vjpfun

        unsort_for_output = jax.custom_transforms(unsort_for_output_impl)
        jax.defvjp_all(unsort_for_output, unsort_for_output_vjp)
        o, logits = unsort_for_output(so, slogits)

        if self.n_hashes > 1:
            o = np.reshape(o, (self.n_hashes, seqlen, o.shape[-1]))
            logits = np.reshape(logits, (self.n_hashes, seqlen, 1))
            probs = np.exp(logits - logsumexp(logits, axis=0, keepdims=True))
            o = np.sum(o * probs, axis=0)

        assert o.shape == (seqlen, w_v.shape[-1])
        out = np.matmul(o, w_o)
        return out, state
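The sort/undo-sort bookkeeping above can be illustrated in isolation. A toy
sketch (assuming only jax; the bucket ids are made up) of sorting values by
key and then inverting the permutation:

import jax
import jax.numpy as jnp

vals = jnp.array([10, 20, 30, 40])
ticker = jnp.arange(4)
buckets = jnp.array([2, 0, 3, 1])     # stand-in for hash bucket ids

# Sort positions by bucket id; sticker records where each element came from.
_, sticker = jax.lax.sort_key_val(buckets, ticker, dimension=-1)
svals = jnp.take(vals, sticker, axis=0)
# Sorting sticker back recovers the inverse permutation.
_, undo_sort = jax.lax.sort_key_val(sticker, ticker, dimension=-1)
restored = jnp.take(svals, undo_sort, axis=0)
assert (restored == vals).all()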
Example #22
def attend(
    q,
    k=None,
    v=None,
    q_chunk_len=None,
    kv_chunk_len=None,
    n_chunks_before=0,
    n_chunks_after=0,
    mask_fn=None,
    q_info=None,
    kv_info=None,
    dropout=0.0,
    rng=None,
):
    """Dot-product attention, with optional chunking and/or masking.

  Args:
    q: Query vectors, shape [q_len, d_qk]
    k: Key vectors, shape [kv_len, d_qk]; or None
    v: Value vectors, shape [kv_len, d_v]
    q_chunk_len: Set to non-zero to enable chunking for query vectors
    kv_chunk_len: Set to non-zero to enable chunking for key/value vectors
    n_chunks_before: Number of adjacent previous chunks to attend to
    n_chunks_after: Number of adjacent subsequent chunks to attend to
    mask_fn: TODO(kitaev) doc
    q_info: Query-associated metadata for masking
    kv_info: Key-associated metadata for masking
    dropout: Dropout rate
    rng: RNG for dropout

  Returns:
    A tuple (output, dots_logsumexp). The output has shape [q_len, d_v], and
    dots_logsumexp has shape [q_len]. The logsumexp of the attention
    probabilities is useful for combining multiple rounds of attention (as in
    LSH attention).
  """
    assert v is not None
    share_qk = (k is None)

    if q_info is None:
        q_info = np.arange(q.shape[-2])

    if kv_info is None and not share_qk:
        kv_info = np.arange(v.shape[-2])

    # Split q/k/v into chunks along the time axis, if desired.
    if q_chunk_len is not None:
        q = np.reshape(q, (-1, q_chunk_len, q.shape[-1]))
        q_info = np.reshape(q_info, (-1, q_chunk_len))

    if share_qk:
        assert kv_chunk_len is None or kv_chunk_len == q_chunk_len
        k = q
        kv_chunk_len = q_chunk_len
        kv_info = q_info
    elif kv_chunk_len is not None:
        k = np.reshape(k, (-1, kv_chunk_len, k.shape[-1]))
        kv_info = np.reshape(kv_info, (-1, kv_chunk_len))

    if kv_chunk_len is not None:
        v = np.reshape(v, (-1, kv_chunk_len, v.shape[-1]))

    if share_qk:
        k = length_normalized(k)
    k = k / np.sqrt(k.shape[-1])

    # Optionally include adjacent chunks.
    if q_chunk_len is not None or kv_chunk_len is not None:
        assert q_chunk_len is not None and kv_chunk_len is not None
    else:
        assert n_chunks_before == 0 and n_chunks_after == 0

    k = look_adjacent(k, n_chunks_before, n_chunks_after)
    v = look_adjacent(v, n_chunks_before, n_chunks_after)
    kv_info = look_adjacent(kv_info, n_chunks_before, n_chunks_after)

    # Dot-product attention.
    dots = np.matmul(q, np.swapaxes(k, -1, -2))

    # Masking
    if mask_fn is not None:
        dots = mask_fn(dots, q_info[..., :, None], kv_info[..., None, :])

    # Softmax.
    dots_logsumexp = logsumexp(dots, axis=-1, keepdims=True)
    dots = np.exp(dots - dots_logsumexp)

    if dropout > 0.0:
        assert rng is not None
        # Dropout is broadcast across the bin dimension
        dropout_shape = (dots.shape[-2], dots.shape[-1])
        # TODO(kitaev): verify that tie-in is safe to remove (in light of jax fix)
        keep_prob = jax.lax.tie_in(dots, 1.0 - dropout)
        keep = jax.random.bernoulli(rng, keep_prob, dropout_shape)
        multiplier = keep.astype(dots.dtype) / jax.lax.tie_in(keep, keep_prob)
        dots = dots * multiplier

    # The softmax normalizer (dots_logsumexp) is used by multi-round LSH attn.
    out = np.matmul(dots, v)
    out = np.reshape(out, (-1, out.shape[-1]))
    dots_logsumexp = np.reshape(dots_logsumexp, (-1, ))
    return out, dots_logsumexp
Example #23
def Softmax(x, axis=-1, **unused_kwargs):
    """Apply softmax to x: exponentiate and normalize along the given axis."""
    return np.exp(x - math.logsumexp(x, axis, keepdims=True))
Example #24
def Exp(x, **unused_kwargs):
    return np.exp(x)
Example #25
 def f(log_probs, advantages, old_log_probs, mask):
     if reweight:  # Use new policy weights for sampled actions instead.
         mask *= jnp.exp(math.stop_gradient(log_probs) - old_log_probs)
     weights = jnp.minimum(awr_weights(advantages, beta), w_max)
     return -jnp.sum(log_probs * weights * mask) / jnp.sum(mask)
Example #26
def awr_weights(advantages, beta):
    return jnp.exp(advantages / beta)
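In the AWR losses above these weights are clipped to w_max before use; a
hedged usage sketch with made-up numbers:

import jax.numpy as jnp

def awr_weights(advantages, beta):
    return jnp.exp(advantages / beta)

advantages = jnp.array([-1.0, 0.0, 2.0])
beta, w_max = 1.0, 5.0

# exp(-1) ~ 0.37, exp(0) = 1.0, exp(2) ~ 7.39 is clipped to 5.0.
weights = jnp.minimum(awr_weights(advantages, beta), w_max)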
Example #27
 def entropy(self, log_probs):
     probs = jnp.exp(log_probs)
     return -jnp.sum(probs * log_probs, axis=-1)
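A quick sanity check (an illustration; assumes log_probs is a normalized
categorical log-distribution over the last axis) that the formula above
recovers log(n) for a uniform distribution:

import jax.numpy as jnp

n = 4
log_probs = jnp.log(jnp.full((n,), 1.0 / n))
entropy = -jnp.sum(jnp.exp(log_probs) * log_probs, axis=-1)
assert jnp.allclose(entropy, jnp.log(float(n)))  # ~1.3863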
Example #28
def Exp():
  """Returns a layer that computes the element-wise exponential of a tensor."""
  return Fn('Exp', lambda x: jnp.exp(x))  # pylint: disable=unnecessary-lambda
Example #29
 def entropy(self, log_probs):
     del log_probs  # would be helpful if self._std was learnable
     return jnp.exp(self._std) + .5 * jnp.log(2.0 * jnp.pi * jnp.e)
Example #30
def AWRLoss(x, beta, w_max, **unused_kwargs):
    """Definition of the Advantage Weighted Regression (AWR) loss."""
    (log_probs, advantages, _) = x
    weights = jnp.minimum(jnp.exp(advantages / beta), w_max)
    return -(log_probs * weights).mean()