Example #1
# Assumed imports (not part of the original snippet): this function matches the
# Trax attention layer, which pulls a backend-agnostic numpy and helpers from
# `trax.backend`; exact import paths vary across Trax versions.
import jax

from trax import backend
from trax.backend import numpy as np


def DotProductAttention(query, key, value, mask, dropout, mode, rng):
    """Core dot product self-attention.

    Args:
      query: array of representations
      key: array of representations
      value: array of representations
      mask: attention-mask, gates attention
      dropout: float: dropout rate
      mode: 'eval' or 'train': whether to use dropout
      rng: JAX PRNGKey: subkey for disposable use

    Returns:
      Self attention for q, k, v arrays.
    """
    # Scaled dot-product scores: QK^T / sqrt(depth), where depth is the size
    # of the last (feature) axis.
    depth = np.shape(query)[-1]
    dots = np.matmul(query, np.swapaxes(key, -1, -2)) / np.sqrt(depth)
    if mask is not None:
        # TODO(kitaev): workaround for https://github.com/google/jax/issues/850
        # We must ensure that both mask and the -1e9 constant have a data dependency
        # on the input. Broadcasted copies of these use a lot of memory, so they
        # should be computed at runtime (rather than being global constants).
        if backend.get_name() == 'jax':
            mask = jax.lax.tie_in(dots, mask)
        dots = np.where(mask, dots, np.full_like(dots, -1e9))
    # Softmax over the last axis (keys), computed via logsumexp in log-space
    # for numerical stability.
    dots = np.exp(dots - backend.logsumexp(dots, axis=-1, keepdims=True))
    if dropout is not None and dropout >= 1.0:
        raise ValueError('Dropout rates must be lower than 1.')
    if dropout is not None and dropout > 0.0 and mode == 'train':
        # Inverted dropout on the attention weights: keep each weight with
        # probability (1 - dropout) and rescale so the expectation is unchanged.
        keep = backend.random.bernoulli(rng, 1.0 - dropout, dots.shape)
        dots = np.where(keep, dots / (1.0 - dropout), np.zeros_like(dots))
    out = np.matmul(dots, value)
    return out
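
A minimal usage sketch (not from the original source): it assumes Trax's JAX backend is active and that the installed JAX version still provides jax.lax.tie_in; the shapes, the causal mask, and the 'eval'-mode call are illustrative choices.

import jax
import jax.numpy as jnp

# Illustrative shapes: (batch, heads, seq_len, d_head); the values are arbitrary.
q = jnp.ones((1, 2, 4, 8))
k = jnp.ones((1, 2, 4, 8))
v = jnp.ones((1, 2, 4, 8))
# Causal mask: position i may only attend to positions <= i.
causal_mask = jnp.tril(jnp.ones((4, 4), dtype=bool))[None, None, :, :]
rng = jax.random.PRNGKey(0)
# In 'eval' mode the dropout branch is skipped, so the rng goes unused here.
out = DotProductAttention(q, k, v, mask=causal_mask, dropout=0.1,
                          mode='eval', rng=rng)
print(out.shape)  # (1, 2, 4, 8)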
Example #2
# Assumed context: `np` is numpy (or a numpy-compatible backend such as
# jax.numpy), `slax.one_hot` builds one-hot vectors, and `kl_div` computes a
# KL-divergence loss; these helpers come from the surrounding codebase and are
# not defined in this snippet.
def label_smoothed_loss(logpred, target, size, padding_idx=0, smoothing=0.0):
    """Computes the label-smoothed KL-divergence loss for log-predictions."""
    confidence = 1.0 - smoothing
    # Spread the smoothing mass over every class except the true class and the
    # padding class, hence the (size - 2) divisor.
    zerosmoothed = smoothing / (size - 2)
    delta = confidence - zerosmoothed
    assert logpred.shape[1] == size
    # Smoothed target distribution: `zerosmoothed` everywhere, plus extra mass
    # on the true class so that it receives `confidence` in total.
    truedist = (np.full_like(logpred, zerosmoothed) +
                delta * slax.one_hot(target, size))
    # Zero out the padding-class column and the rows whose target is padding.
    truedist *= (1 - (np.arange(size) == padding_idx))
    truedist *= (1 - (target == padding_idx))[:, np.newaxis]
    return kl_div(logpred, truedist, eps=1e-6)
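
A self-contained usage sketch (not from the original source): `slax` and `kl_div` are not defined in the snippet, so hypothetical stand-ins are supplied here purely so the call runs; the real helpers may behave differently. The vocabulary size, targets, and smoothing value are illustrative.

import types

import numpy as np

# Hypothetical stand-in for slax.one_hot: rows of the identity matrix.
slax = types.SimpleNamespace(one_hot=lambda indices, size: np.eye(size)[indices])

def kl_div(logpred, truedist, eps=1e-6):
    # Hypothetical stand-in: KL(truedist || exp(logpred)) summed over classes,
    # averaged over rows; eps keeps log(0) finite for zero target entries.
    return np.mean(np.sum(truedist * (np.log(truedist + eps) - logpred), axis=-1))

# Three target tokens over a vocabulary of size 5; index 0 is padding.
rng = np.random.default_rng(0)
logits = rng.normal(size=(3, 5))
logpred = logits - np.log(np.sum(np.exp(logits), axis=-1, keepdims=True))
target = np.array([2, 4, 0])
print(label_smoothed_loss(logpred, target, size=5, smoothing=0.1))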