def DotProductAttention(query, key, value, mask, dropout, mode, rng):
  """Core dot product self-attention.

  Args:
    query: array of representations
    key: array of representations
    value: array of representations
    mask: attention mask, gates attention
    dropout: float: dropout rate
    mode: 'eval' or 'train': whether to use dropout
    rng: JAX PRNGKey: subkey for disposable use

  Returns:
    Self attention for q, k, v arrays.
  """
  depth = np.shape(query)[-1]
  # Scaled dot-product scores: (q . k^T) / sqrt(depth).
  dots = np.matmul(query, np.swapaxes(key, -1, -2)) / np.sqrt(depth)
  if mask is not None:
    # TODO(kitaev): workaround for https://github.com/google/jax/issues/850
    # We must ensure that both mask and the -1e9 constant have a data dependency
    # on the input. Broadcasted copies of these use a lot of memory, so they
    # should be computed at runtime (rather than being global constants).
    if backend.get_name() == 'jax':
      mask = jax.lax.tie_in(dots, mask)
    dots = np.where(mask, dots, np.full_like(dots, -1e9))
  # Softmax.
  dots = np.exp(dots - backend.logsumexp(dots, axis=-1, keepdims=True))
  if dropout is not None and dropout >= 1.0:
    raise ValueError('Dropout rates must be lower than 1.')
  if dropout is not None and dropout > 0.0 and mode == 'train':
    # Inverted dropout on the attention weights.
    keep = backend.random.bernoulli(rng, 1.0 - dropout, dots.shape)
    dots = np.where(keep, dots / (1.0 - dropout), np.zeros_like(dots))
  out = np.matmul(dots, value)
  return out
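

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): assumes this module's `np` is the
# JAX-backed numpy used above and that `backend` resolves to the JAX backend,
# so plain jax.numpy arrays can be passed in. The function name, shapes, and
# values below are made up for the example.
# ---------------------------------------------------------------------------
def _example_dot_product_attention():
  """Hypothetical example: causal self-attention over a toy batch."""
  import jax
  import jax.numpy as jnp

  batch, seq_len, depth = 2, 4, 8
  q = jnp.ones((batch, seq_len, depth))
  k = jnp.ones((batch, seq_len, depth))
  v = jnp.ones((batch, seq_len, depth))
  # Lower-triangular (causal) mask, broadcast over the batch dimension.
  causal_mask = jnp.tril(jnp.ones((1, seq_len, seq_len), dtype=bool))

  out = DotProductAttention(q, k, v, causal_mask,
                            dropout=0.0, mode='eval',
                            rng=jax.random.PRNGKey(0))
  assert out.shape == (batch, seq_len, depth)  # same shape as `value`
  return out
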
def label_smoothed_loss(logpred, target, size, padding_idx=0, smoothing=0.0):
  """Computes the label-smoothed loss as KL divergence against smoothed targets."""
  confidence = 1.0 - smoothing
  zerosmoothed = smoothing / (size - 2)
  delta = confidence - zerosmoothed
  assert logpred.shape[1] == size
  # Smoothed target distribution: the true class gets `confidence`, every other
  # non-padding class gets `smoothing / (size - 2)`.
  truedist = (np.full_like(logpred, zerosmoothed)
              + delta * slax.one_hot(target, size))
  # Zero out the padding class column and the rows whose target is padding.
  truedist *= (1 - (np.arange(size) == padding_idx))
  truedist *= (1 - (target == padding_idx))[:, np.newaxis]
  return kl_div(logpred, truedist, eps=1e-6)
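

# ---------------------------------------------------------------------------
# Illustrative sketch (toy numbers, not from the code above): with
# smoothing=0.1 and size=5, the true class gets 1 - 0.1 = 0.9 and each of the
# size - 2 = 3 remaining non-padding classes gets 0.1 / 3, so each row sums to
# 1 before padded rows are zeroed out. The call below assumes `logpred` holds
# log-probabilities (e.g. from jax.nn.log_softmax) and that the module-level
# helpers used above (`slax.one_hot`, `kl_div`) are available; the function
# name and values are hypothetical.
# ---------------------------------------------------------------------------
def _example_label_smoothed_loss():
  """Hypothetical example: smoothed loss on a toy batch with one padded row."""
  import jax
  import jax.numpy as jnp

  vocab_size = 5
  targets = jnp.array([2, 4, 0])               # last position is padding
  logits = jnp.zeros((3, vocab_size))          # uniform predictions
  logpred = jax.nn.log_softmax(logits, axis=-1)

  return label_smoothed_loss(logpred, targets, vocab_size,
                             padding_idx=0, smoothing=0.1)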