Example #1
def multigaussian_loss(preds, targets, ngauss=1):  # pylint: disable=invalid-name
    """Compute mixture of gaussians loss."""
    ndims = targets.shape[-1]
    logits = preds[:, :ngauss]                   # Mixture weights (logits).
    mus = preds[:, ngauss:ngauss * (ndims + 1)]  # Component means.
    sigmas = preds[:, ngauss * (ndims + 1):]     # Component scales.
    sigmas = sigmas * sigmas + 1e-6  # Make positive.
    loglogits = logits - math.logsumexp(logits, axis=-1, keepdims=True)
    mus = jnp.reshape(mus, [-1, ngauss, ndims])
    sigmas = jnp.reshape(sigmas, [-1, ngauss, ndims])
    targets = jnp.reshape(targets, [-1, 1, ndims])
    glogprobs = log_gaussian_diag_pdf(targets, mus, sigmas)
    return math.logsumexp(loglogits + glogprobs, axis=-1)
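The slicing above assumes `preds` packs, per example, `ngauss` mixture logits followed by `ngauss * ndims` means and `ngauss * ndims` scale values (total width `ngauss * (2 * ndims + 1)`). Below is a minimal standalone sketch of that layout in plain NumPy/SciPy; the diagonal-Gaussian log-pdf helper is an assumed stand-in for `log_gaussian_diag_pdf` (taking per-dimension variances), not the trax API itself.

import numpy as onp
from scipy.special import logsumexp

def log_gaussian_diag_pdf_sketch(x, mu, var):
    # Log-density of a diagonal Gaussian; `var` holds per-dimension variances.
    return -0.5 * onp.sum(onp.log(2.0 * onp.pi * var) + (x - mu) ** 2 / var,
                          axis=-1)

batch, ndims, ngauss = 4, 3, 2
preds = onp.random.randn(batch, ngauss * (2 * ndims + 1))
targets = onp.random.randn(batch, ndims)

logits = preds[:, :ngauss]
mus = preds[:, ngauss:ngauss * (ndims + 1)].reshape(batch, ngauss, ndims)
sigmas = preds[:, ngauss * (ndims + 1):].reshape(batch, ngauss, ndims)
sigmas = sigmas * sigmas + 1e-6                      # Make positive.
loglogits = logits - logsumexp(logits, axis=-1, keepdims=True)
glogprobs = log_gaussian_diag_pdf_sketch(targets[:, None, :], mus, sigmas)
loglik = logsumexp(loglogits + glogprobs, axis=-1)   # Shape (batch,).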
Example #2
def DotProductAttention(query, key, value, mask, dropout, mode, rng):
    """Core dot product self-attention.

  Args:
    query: array of representations
    key: array of representations
    value: array of representations
    mask: attention-mask, gates attention
    dropout: float: dropout rate
    mode: 'eval' or 'train': whether to use dropout
    rng: JAX PRNGKey: subkey for disposable use

  Returns:
    Self attention for q, k, v arrays.
  """
    depth = np.shape(query)[-1]
    dots = np.matmul(query, np.swapaxes(key, -1, -2)) / np.sqrt(depth)
    if mask is not None:
        # TODO(kitaev): workaround for https://github.com/google/jax/issues/850
        # We must ensure that both mask and the -1e9 constant have a data dependency
        # on the input. Broadcasted copies of these use a lot of memory, so they
        # should be computed at runtime (rather than being global constants).
        if math.backend_name() == 'jax':
            mask = jax.lax.tie_in(dots, mask)
        # JAX's `full_like` already ties in -1e9 to dots.
        dots = np.where(mask, dots, np.full_like(dots, -1e9))
    # Softmax.
    dots = np.exp(dots - math.logsumexp(dots, axis=-1, keepdims=True))
    if dropout is not None and dropout >= 1.0:
        raise ValueError('Dropout rates must be lower than 1.')
    if dropout is not None and dropout > 0.0 and mode == 'train':
        keep = math.random.bernoulli(rng, 1.0 - dropout, dots.shape)
        dots = np.where(keep, dots / (1.0 - dropout), np.zeros_like(dots))
    out = np.matmul(dots, value)
    return out
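A minimal pure-NumPy/SciPy sketch of the same masked-softmax attention pattern; the `(batch*heads, seq_len, d_head)` shapes and the causal mask below are illustrative assumptions, not the trax calling convention.

import numpy as onp
from scipy.special import logsumexp

q = onp.random.randn(2, 5, 8)                     # (batch*heads, seq_len, d_head)
k = onp.random.randn(2, 5, 8)
v = onp.random.randn(2, 5, 8)
mask = onp.tril(onp.ones((1, 5, 5), dtype=bool))  # e.g. a causal mask

dots = onp.matmul(q, onp.swapaxes(k, -1, -2)) / onp.sqrt(q.shape[-1])
dots = onp.where(mask, dots, -1e9)                # Screen out masked positions.
weights = onp.exp(dots - logsumexp(dots, axis=-1, keepdims=True))
out = onp.matmul(weights, v)                      # (2, 5, 8)
assert onp.allclose(weights.sum(axis=-1), 1.0)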
Example #3
def Softmax(axis=-1):
  """Returns a layer that applies softmax along one tensor axis.

  `Softmax` acts on a group of values and normalizes them to look like a set
  of probability values. (Probability values must be non-negative, and as a
  set must sum to 1.)

  Args:
    axis: Axis along which values are grouped for computing softmax.
  """
  return Fn('Softmax',
            lambda x: jnp.exp(x - math.logsumexp(x, axis, keepdims=True)))
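A small numeric check, in plain NumPy/SciPy, of the identity this layer relies on: softmax(x) == exp(x - logsumexp(x)). Subtracting the log-normalizer keeps the largest shifted entry at 0, so the exponentials cannot overflow.

import numpy as onp
from scipy.special import logsumexp

x = onp.array([1000.0, 1001.0, 1002.0])            # Naive exp(x) would overflow.
probs = onp.exp(x - logsumexp(x, axis=-1, keepdims=True))
print(probs, probs.sum())                          # ~[0.090, 0.245, 0.665], sums to 1.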
Example #4
def LogSoftmax(axis=-1):
  """Returns a layer that applies log softmax along one tensor axis.

  `LogSoftmax` acts on a group of values and normalizes them to look like a set
  of log probability values. (Probability values must be non-negative, and as
  a set must sum to 1. A group of log probability values can be seen as the
  natural logarithm function applied to a set of probability values.)

  Args:
    axis: Axis along which values are grouped for computing log softmax.
  """
  return Fn('LogSoftmax',
            lambda x: x - math.logsumexp(x, axis, keepdims=True))
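A companion numeric check, again in plain NumPy/SciPy: x - logsumexp(x) equals log(softmax(x)) but stays finite even where the softmax entries underflow to exactly 0 in floating point.

import numpy as onp
from scipy.special import logsumexp

x = onp.array([0.0, -2000.0])
log_probs = x - logsumexp(x, axis=-1, keepdims=True)
print(log_probs)                     # [0., -2000.]; naive log(softmax(x)) gives -inf.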
Example #5
def DotProductAttention(queries, keys, values, mask, dropout, mode, rng):
    """Computes new activations via masked attention-weighted sum of values.

  This function is the core of the attention mechanism. It:
    - computes per-head attention weights from per-head `(queries, keys)`,
    - applies `mask` to screen out positions that come from padding tokens,
    - optionally applies dropout to attention weights, and
    - uses attention weights to combine per-head `values` vectors.

  Args:
    queries: Per-head activations representing attention queries.
    keys: Per-head activations representing attention keys.
    values: Per-head activations to be combined by computed attention weights.
    mask: Mask that distinguishes positions with real content vs. padding.
    dropout: Probabilistic rate for dropout applied to attention activations
        (based on query-key pairs) before dotting them with values.
    mode: Either 'train' or 'eval'. Dropout applies only in 'train' mode.
    rng: Single-use random number generator (JAX PRNG key).

  Returns:
    Per-head activations resulting from masked per-head attention-weighted
    sum of per-head values.
  """
    d_feature = queries.shape[-1]
    dots = jnp.matmul(queries, jnp.swapaxes(keys, -1, -2)) / jnp.sqrt(d_feature)
    if mask is not None:
        # TODO(kitaev): workaround for https://github.com/google/jax/issues/850
        # We must ensure that both mask and the -1e9 constant have a data dependency
        # on the input. Broadcasted copies of these use a lot of memory, so they
        # should be computed at runtime (rather than being global constants).
        if math.backend_name() == 'jax':
            mask = jax.lax.tie_in(dots, mask)
        # JAX's `full_like` already ties in -1e9 to dots.
        dots = jnp.where(mask, dots, jnp.full_like(dots, -1e9))
    # Softmax.
    dots = jnp.exp(dots - math.logsumexp(dots, axis=-1, keepdims=True))
    if dropout is not None and dropout >= 1.0:
        raise ValueError('Dropout rates must be lower than 1.')
    if dropout is not None and dropout > 0.0 and mode == 'train':
        keep = math.random.bernoulli(rng, 1.0 - dropout, dots.shape)
        dots = jnp.where(keep, dots / (1.0 - dropout), jnp.zeros_like(dots))
    out = jnp.matmul(dots, values)
    return out
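A minimal NumPy sketch of the inverted-dropout step above: kept weights are rescaled by 1 / (1 - rate) so their expected value matches the no-dropout weights. The shapes, seed, and names here are illustrative assumptions.

import numpy as onp

rng = onp.random.default_rng(0)
rate = 0.1
weights = rng.random((4, 16, 16))                 # Stand-in attention weights.
keep = rng.random(weights.shape) > rate           # Bernoulli(1 - rate) keep mask.
dropped = onp.where(keep, weights / (1.0 - rate), 0.0)
print(weights.mean(), dropped.mean())             # Close in expectation.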
Example #6
def Softmax(axis=-1):
    """Layer that applies softmax: exponentiate and normalize along given axis."""
    return Fn('Softmax',
              lambda x: jnp.exp(x - math.logsumexp(x, axis, keepdims=True)))
Example #7
def LogSoftmax(axis=-1):
    """Layer that applies log softmax: log-normalize along the given axis."""
    return Fn('LogSoftmax',
              lambda x: x - math.logsumexp(x, axis, keepdims=True))
Example #8
def Softmax(x, axis=-1, **unused_kwargs):
    """Apply softmax to x: exponentiate and normalize along the given axis."""
    return np.exp(x - math.logsumexp(x, axis, keepdims=True))
Example #9
def LogSoftmax(x, axis=-1, **unused_kwargs):
    """Apply log softmax to x: log-normalize along the given axis."""
    return x - math.logsumexp(x, axis, keepdims=True)