def multigaussian_loss(preds, targets, ngauss=1):  # pylint: disable=invalid-name
  """Compute mixture of gaussians loss."""
  ndims = targets.shape[-1]
  logits = preds[:, :ngauss]
  mus = preds[:, ngauss:ngauss * (ndims + 1)]
  sigmas = preds[:, ngauss * (ndims + 1):]
  sigmas = sigmas * sigmas + 1e-6  # Make positive.
  loglogits = logits - math.logsumexp(logits, axis=-1, keepdims=True)
  mus = jnp.reshape(mus, [-1, ngauss, ndims])
  sigmas = jnp.reshape(sigmas, [-1, ngauss, ndims])
  targets = jnp.reshape(targets, [-1, 1, ndims])
  glogprobs = log_gaussian_diag_pdf(targets, mus, sigmas)
  return math.logsumexp(loglogits + glogprobs, axis=-1)
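# Usage sketch (illustrative only, not part of the original module). It
# assumes the packed `preds` layout implied by the slicing above: `ngauss`
# mixture logits, then `ngauss * ndims` means, then `ngauss * ndims` raw
# sigmas, for a total width of ngauss * (2 * ndims + 1). The helper name and
# the concrete shapes are assumptions made for the example.
def _multigaussian_loss_example():
  batch, ngauss, ndims = 8, 3, 2
  preds = jnp.ones((batch, ngauss * (2 * ndims + 1)))  # packed mixture params
  targets = jnp.zeros((batch, ndims))                  # points to score
  loglik = multigaussian_loss(preds, targets, ngauss=ngauss)
  assert loglik.shape == (batch,)  # one mixture log-likelihood per example
  return loglik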
def DotProductAttention(query, key, value, mask, dropout, mode, rng):
  """Core dot product self-attention.

  Args:
    query: array of representations
    key: array of representations
    value: array of representations
    mask: attention-mask, gates attention
    dropout: float: dropout rate
    mode: 'eval' or 'train': whether to use dropout
    rng: JAX PRNGKey: subkey for disposable use

  Returns:
    Self attention for q, k, v arrays.
  """
  depth = np.shape(query)[-1]
  dots = np.matmul(query, np.swapaxes(key, -1, -2)) / np.sqrt(depth)
  if mask is not None:
    # TODO(kitaev): workaround for https://github.com/google/jax/issues/850
    # We must ensure that both mask and the -1e9 constant have a data
    # dependency on the input. Broadcasted copies of these use a lot of
    # memory, so they should be computed at runtime (rather than being global
    # constants).
    if math.backend_name() == 'jax':
      mask = jax.lax.tie_in(dots, mask)
    # JAX's `full_like` already ties in -1e9 to dots.
    dots = np.where(mask, dots, np.full_like(dots, -1e9))
  # Softmax.
  dots = np.exp(dots - math.logsumexp(dots, axis=-1, keepdims=True))
  if dropout >= 1.0:
    raise ValueError('Dropout rates must be lower than 1.')
  if dropout is not None and dropout > 0.0 and mode == 'train':
    keep = math.random.bernoulli(rng, 1.0 - dropout, dots.shape)
    dots = np.where(keep, dots / (1.0 - dropout), np.zeros_like(dots))
  out = np.matmul(dots, value)
  return out
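# Usage sketch (an illustrative assumption, not library code). Runs the
# attention above in 'eval' mode with a causal (lower-triangular) mask, using
# the (batch, seqlen, depth) shapes its callers pass in. The helper name and
# shapes are assumptions; `np` is the backend numpy used by this module, and
# `jax.random.PRNGKey` supplies the rng (unused here since dropout=0).
def _dot_product_attention_example():
  batch, seqlen, depth = 2, 5, 4
  q = np.ones((batch, seqlen, depth))
  k = np.ones((batch, seqlen, depth))
  v = np.ones((batch, seqlen, depth))
  # Position i may attend only to positions <= i.
  causal_mask = np.tril(np.ones((1, seqlen, seqlen), dtype=np.bool_))
  out = DotProductAttention(q, k, v, causal_mask,
                            dropout=0.0, mode='eval',
                            rng=jax.random.PRNGKey(0))
  assert out.shape == (batch, seqlen, depth)
  return out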
def Softmax(axis=-1):
  """Returns a layer that applies softmax along one tensor axis.

  `Softmax` acts on a group of values and normalizes them to look like a set
  of probability values. (Probability values must be non-negative, and as a
  set must sum to 1.)

  Args:
    axis: Axis along which values are grouped for computing softmax.
  """
  return Fn('Softmax',
            lambda x: jnp.exp(x - math.logsumexp(x, axis, keepdims=True)))
def LogSoftmax(axis=-1):
  """Returns a layer that applies log softmax along one tensor axis.

  `LogSoftmax` acts on a group of values and normalizes them to look like a
  set of log probability values. (Probability values must be non-negative,
  and as a set must sum to 1. A group of log probability values can be seen
  as the natural logarithm function applied to a set of probability values.)

  Args:
    axis: Axis along which values are grouped for computing log softmax.
  """
  return Fn('LogSoftmax',
            lambda x: x - math.logsumexp(x, axis, keepdims=True))
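# Minimal sketch (an assumption, not part of the module) showing both layers
# applied directly to a small batch; parameterless `Fn` layers can be called
# on arrays like plain functions in trax.
def _softmax_layers_example():
  x = jnp.array([[2.0, 1.0, 0.1]])
  probs = Softmax()(x)        # rows are non-negative and sum to 1
  logprobs = LogSoftmax()(x)  # equals jnp.log(probs) up to numerical error
  assert jnp.allclose(jnp.sum(probs, axis=-1), 1.0)
  assert jnp.allclose(logprobs, jnp.log(probs), atol=1e-6)
  return probs, logprobs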
def DotProductAttention(queries, keys, values, mask, dropout, mode, rng):
  """Computes new activations via masked attention-weighted sum of values.

  This function is the core of the attention mechanism. It:
    - computes per-head attention weights from per-head `(queries, keys)`,
    - applies `mask` to screen out positions that come from padding tokens,
    - optionally applies dropout to attention weights, and
    - uses attention weights to combine per-head `values` vectors.

  Args:
    queries: Per-head activations representing attention queries.
    keys: Per-head activations representing attention keys.
    values: Per-head activations to be combined by computed attention weights.
    mask: Mask that distinguishes positions with real content vs. padding.
    dropout: Probabilistic rate for dropout applied to attention activations
        (based on query-key pairs) before dotting them with values.
    mode: Either 'train' or 'eval'. Dropout applies only in 'train' mode.
    rng: Single-use random number generator (JAX PRNG key).

  Returns:
    Per-head activations resulting from masked per-head attention-weighted
    sum of per-head values.
  """
  d_feature = queries.shape[-1]
  dots = jnp.matmul(queries, jnp.swapaxes(keys, -1, -2)) / jnp.sqrt(d_feature)
  if mask is not None:
    # TODO(kitaev): workaround for https://github.com/google/jax/issues/850
    # We must ensure that both mask and the -1e9 constant have a data
    # dependency on the input. Broadcasted copies of these use a lot of
    # memory, so they should be computed at runtime (rather than being global
    # constants).
    if math.backend_name() == 'jax':
      mask = jax.lax.tie_in(dots, mask)
    # JAX's `full_like` already ties in -1e9 to dots.
    dots = jnp.where(mask, dots, jnp.full_like(dots, -1e9))
  # Softmax.
  dots = jnp.exp(dots - math.logsumexp(dots, axis=-1, keepdims=True))
  if dropout >= 1.0:
    raise ValueError('Dropout rates must be lower than 1.')
  if dropout is not None and dropout > 0.0 and mode == 'train':
    keep = math.random.bernoulli(rng, 1.0 - dropout, dots.shape)
    dots = jnp.where(keep, dots / (1.0 - dropout), jnp.zeros_like(dots))
  out = jnp.matmul(dots, values)
  return out
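# Usage sketch (an illustrative assumption, not library code). Exercises the
# 'train' path of the function above, where dropout is applied to the
# attention weights; the padding mask marks the last two key positions as
# padding. The helper name and shapes are assumptions made for the example.
def _masked_attention_train_example():
  batch, seqlen, d_feature = 2, 7, 8
  queries = jnp.ones((batch, seqlen, d_feature))
  keys = jnp.ones((batch, seqlen, d_feature))
  values = jnp.ones((batch, seqlen, d_feature))
  mask = jnp.arange(seqlen) < seqlen - 2  # broadcasts across query positions
  out = DotProductAttention(queries, keys, values, mask,
                            dropout=0.1, mode='train',
                            rng=jax.random.PRNGKey(42))
  assert out.shape == (batch, seqlen, d_feature)
  return out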
def Softmax(axis=-1):
  """Layer that applies softmax: exponentiate and normalize along given axis."""
  return Fn('Softmax',
            lambda x: jnp.exp(x - math.logsumexp(x, axis, keepdims=True)))
def LogSoftmax(axis=-1):
  """Layer that applies log softmax: log-normalize along the given axis."""
  return Fn('LogSoftmax',
            lambda x: x - math.logsumexp(x, axis, keepdims=True))
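# Illustrative check (an assumption, not part of the module). The
# logsumexp-based formulation used by both layers is numerically stable: even
# for very large logits, where a naive exp(x) / sum(exp(x)) would overflow,
# the result stays finite and correct.
def _softmax_stability_example():
  big = jnp.array([[1000.0, 1000.0, 1000.0]])
  probs = Softmax()(big)
  assert bool(jnp.all(jnp.isfinite(probs)))
  assert jnp.allclose(probs, jnp.full_like(probs, 1.0 / 3.0))
  return probs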
def Softmax(x, axis=-1, **unused_kwargs):
  """Apply softmax to x: exponentiate and normalize along the given axis."""
  return np.exp(x - math.logsumexp(x, axis, keepdims=True))
def LogSoftmax(x, axis=-1, **unused_kwargs):
  """Apply log softmax to x: log-normalize along the given axis."""
  return x - math.logsumexp(x, axis, keepdims=True)
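# Usage sketch (an assumption, not library code) for the pure-function forms
# above: they take the array directly and accept a non-default axis.
def _softmax_axis_example():
  x = np.arange(6.0).reshape(2, 3)
  col_probs = Softmax(x, axis=0)         # normalize down each column
  row_logprobs = LogSoftmax(x, axis=-1)  # log-normalize along each row
  assert np.allclose(np.sum(col_probs, axis=0), 1.0)
  assert np.allclose(np.sum(np.exp(row_logprobs), axis=-1), 1.0)
  return col_probs, row_logprobs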