def _update_diagonal(self, g, w, m, v1, v2, opt_params):
  """Update for lower-rank parameters, using a single diagonal accumulator."""
  learning_rate = opt_params['learning_rate']
  beta2 = opt_params['second_moment_averaging']
  weight_decay = opt_params['weight_decay']

  is_beta2_1 = (beta2 == 1).astype(g.dtype)
  one_minus_beta2_except1 = is_beta2_1 + (1.0 - beta2) * (1.0 - is_beta2_1)
  v1[0] = beta2 * v1[0] + one_minus_beta2_except1 * g * g

  preconditioner = jnp.where(v1[0] > 0, 1.0 / (jnp.sqrt(v1[0]) + 1e-16),
                             jnp.zeros_like(v1[0]))
  pg = preconditioner * g
  if self._graft:
    # Grafting: rescale the preconditioned update to the norm of an
    # AdaGrad-style update computed from the un-averaged accumulator v2.
    v2[0] += g * g
    preconditioner_graft = jnp.where(
        v2[0] > 0, 1.0 / (jnp.sqrt(v2[0]) + 1e-16), jnp.zeros_like(v2[0]))
    pg_graft = preconditioner_graft * g
    pg_norm = jnp.linalg.norm(pg)
    pg_graft_norm = jnp.linalg.norm(pg_graft)
    pg = pg * (pg_graft_norm / (pg_norm + 1e-16))

  pg = pg + w * weight_decay

  if self._has_momentum:
    m, update = self._momentum_update(pg, m, opt_params['momentum'])
  else:
    update = pg

  w = w - (update * learning_rate).astype(w.dtype)
  return w, (m, v1, v2)
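A minimal standalone sketch of the diagonal-preconditioner step above on toy values, assuming only `jax.numpy`; the names `grad`, `weights`, and `v1_slot` are illustrative, and the beta2 == 1 special case, grafting, and momentum are omitted:

import jax.numpy as jnp

# Illustrative values, not the optimizer's real state.
grad = jnp.array([0.5, -1.0, 0.0])
weights = jnp.array([0.1, 0.2, 0.3])
v1_slot = jnp.zeros_like(grad)      # running average of squared gradients
beta2, learning_rate = 0.9, 0.01

# Same accumulator / preconditioner math as _update_diagonal.
v1_slot = beta2 * v1_slot + (1.0 - beta2) * grad * grad
precond = jnp.where(v1_slot > 0, 1.0 / (jnp.sqrt(v1_slot) + 1e-16),
                    jnp.zeros_like(v1_slot))
weights = weights - learning_rate * precond * grad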
def learning_rate(step):
  """Step to learning rate function."""
  ret = 1.0
  for name in factors:
    if name == 'constant':
      ret *= constant
    elif name == 'linear_warmup':
      ret *= jnp.minimum(1.0, step / warmup_steps)
    elif name == 'rsqrt_decay':
      ret /= jnp.sqrt(jnp.maximum(step, warmup_steps))
    elif name == 'rsqrt_normalized_decay':
      ret *= jnp.sqrt(warmup_steps)
      ret /= jnp.sqrt(jnp.maximum(step, warmup_steps))
    elif name == 'decay_every':
      ret *= (decay_factor**(step // steps_per_decay))
    elif name == 'cosine_decay':
      progress = jnp.maximum(0.0,
                             (step - warmup_steps) / float(steps_per_cycle))
      ret *= (0.5 * (1.0 + jnp.cos(jnp.pi * (progress % 1.0))))
    else:
      raise ValueError('Unknown factor %s.' % name)
  # TODO(henrykm): return float(jnp.max(minimum, ret)) would be
  # better but causes TypeError: 'numpy.float64' object cannot
  # be interpreted as an integer
  if ret <= minimum:
    return minimum
  return ret
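A hedged usage sketch: `learning_rate` closes over `factors`, `constant`, `warmup_steps`, `minimum`, etc. In the library these come from the enclosing schedule factory; if the def above is pasted at module scope, the stand-in values below (chosen for illustration only) serve the same role:

import jax.numpy as jnp

# Stand-in closure variables; only the factors actually listed are needed.
factors = ['constant', 'linear_warmup', 'rsqrt_decay']
constant = 0.1
warmup_steps = 400
minimum = 1e-5

for step in [1, 100, 400, 10000]:
  print(step, learning_rate(step))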
def _update_sketched(self, g, w, m, v1, v2, opt_params):
  """Update for higher-rank parameters."""
  learning_rate = opt_params['learning_rate']
  momentum = opt_params['momentum']
  beta2 = opt_params['second_moment_averaging']
  weight_decay = opt_params['weight_decay']

  shape = w.shape
  rank = len(shape)
  reshaped_accumulators = [
      jnp.reshape(v1[i], self._expanded_shape(shape, i)) for i in range(rank)
  ]
  acc = self._minimum(reshaped_accumulators)

  is_beta2_1 = (beta2 == 1).astype(g.dtype)
  one_minus_beta2_except1 = is_beta2_1 + (1.0 - beta2) * (1.0 - is_beta2_1)
  acc = beta2 * acc + one_minus_beta2_except1 * g * g

  preconditioner = jnp.where(acc > 0.0, 1.0 / (jnp.sqrt(acc) + 1e-16),
                             jnp.zeros_like(acc))
  pg = g * preconditioner
  if self._graft:
    v2_acc = self._minimum([
        jnp.reshape(v2[i], self._expanded_shape(shape, i))
        for i in range(rank)
    ])
    v2_acc = v2_acc + g * g
    preconditioner_graft = jnp.where(v2_acc > 0.0,
                                     1.0 / (jnp.sqrt(v2_acc) + 1e-16),
                                     jnp.zeros_like(v2_acc))
    pg_graft = preconditioner_graft * g
    pg_norm = jnp.linalg.norm(pg)
    pg_graft_norm = jnp.linalg.norm(pg_graft)
    pg = pg * (pg_graft_norm / (pg_norm + 1e-16))

  pg = pg + w * weight_decay

  if self._has_momentum:
    m, update = self._momentum_update(pg, m, momentum)
  else:
    update = pg

  w = w - (learning_rate * update).astype(w.dtype)
  for i in range(len(v1)):
    axes = list(range(int(i))) + list(range(int(i) + 1, rank))
    dim_accumulator = jnp.amax(acc, axis=axes)
    v1[i] = dim_accumulator

  if self._graft:
    for i in range(len(v2)):
      axes = list(range(int(i))) + list(range(int(i) + 1, rank))
      dim_accumulator = jnp.amax(v2_acc, axis=axes)
      v2[i] = dim_accumulator
  return w, (m, v1, v2)
def Gelu():
  r"""Returns a layer that computes the Gaussian Error Linear Unit function.

  .. math::
      f(x) = \frac{x}{2} \cdot (1 + \hbox{erf}(\frac{x}{\sqrt{2}}))
  """
  return Fn('Gelu',
            lambda x: x * 0.5 * (1.0 + fastmath.erf(x / jnp.sqrt(2.0))))
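A short usage sketch, assuming the Trax context this snippet comes from (`Fn` from trax.layers, `fastmath` and `jnp` from trax.fastmath); the input values are arbitrary:

import numpy as np

gelu = Gelu()                 # Trax layers are callable on plain arrays.
x = np.array([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=np.float32)
y = gelu(x)                   # approx. [-0.046, -0.159, 0.0, 0.841, 1.954]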
def DotProductAttention(query, key, value, mask): """Dot product self-attention. Args: query (jax.interpreters.xla.DeviceArray): array of query representations with shape (L_q by d) key (jax.interpreters.xla.DeviceArray): array of key representations with shape (L_k by d) value (jax.interpreters.xla.DeviceArray): array of value representations with shape (L_k by d) where L_v = L_k mask (jax.interpreters.xla.DeviceArray): attention-mask, gates attention with shape (L_q by L_k) Returns: jax.interpreters.xla.DeviceArray: Self-attention array for q, k, v arrays. (L_q by L_k) """ assert query.shape[-1] == key.shape[-1] == value.shape[ -1], "Embedding dimensions of q, k, v aren't all the same" depth = query.shape[-1] # Calculate scaled query key dot product according to formula above dots = jnp.matmul(query, jnp.swapaxes(key, -1, -2)) / jnp.sqrt(depth) if mask is not None: # The 'None' in this line does not need to be replaced dots = jnp.where(mask, dots, jnp.full_like(dots, -1e9)) # Softmax formula implementation logsumexp = trax.fastmath.logsumexp(dots, axis=-1, keepdims=True) dots = jnp.exp(dots - logsumexp) attention = jnp.matmul(dots, value) return attention
def forward(self, x):
  scale, bias = self.weights
  # Normalize over the last (feature) axis.
  mean = jnp.mean(x, axis=-1, keepdims=True)
  sub = x - mean
  variance = jnp.mean(sub * sub, axis=-1, keepdims=True)
  norm_inputs = sub / jnp.sqrt(variance + self._epsilon)
  # Apply the learned elementwise scale and bias.
  return norm_inputs * scale + bias
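A small numeric sanity check of the normalization math above (not the layer API): after the mean/variance step each row has roughly zero mean and unit variance before scale and bias are applied. Plain `jax.numpy` only; `eps` is an assumed stand-in for `self._epsilon`:

import jax.numpy as jnp

x = jnp.array([[1.0, 2.0, 3.0, 4.0]])
eps = 1e-6
mean = jnp.mean(x, axis=-1, keepdims=True)
sub = x - mean
variance = jnp.mean(sub * sub, axis=-1, keepdims=True)
norm = sub / jnp.sqrt(variance + eps)
print(jnp.mean(norm), jnp.var(norm))   # ~0.0 and ~1.0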
def log_prob(self, inputs, point):
  point = point.reshape(inputs.shape[:-1] + (-1,))
  return (
      # L2 term.
      -jnp.sum((point - inputs)**2, axis=-1) / (2 * self._std**2) -
      # Normalizing constant.
      ((jnp.log(self._std) + jnp.log(jnp.sqrt(2 * jnp.pi))) *
       np.prod(self._shape))
  )
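A standalone check of the same Gaussian log-density formula, with assumed scalar `std` and tuple `shape` standing in for `self._std` / `self._shape`, compared against the sum of independent 1-D normal log-pdfs:

import numpy as np
import jax.numpy as jnp

std, shape = 0.5, (3,)
inputs = jnp.zeros(shape)                  # predicted mean
point = jnp.array([0.1, -0.2, 0.3])

log_p = (-jnp.sum((point - inputs)**2) / (2 * std**2)
         - (jnp.log(std) + jnp.log(jnp.sqrt(2 * jnp.pi))) * np.prod(shape))

# Reference: sum of independent N(mean, std) log-densities.
ref = np.sum(-0.5 * ((np.array(point) - np.array(inputs)) / std)**2
             - np.log(std) - 0.5 * np.log(2 * np.pi))
assert np.isclose(float(log_p), ref)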
def update(self, step, grads, weights, avg_sq_grad, opt_params):
  del step
  lr = opt_params['learning_rate']
  gamma = opt_params['gamma']
  eps = opt_params['eps']
  # Exponential moving average of squared gradients (RMSProp accumulator).
  avg_sq_grad = avg_sq_grad * gamma + grads**2 * (1. - gamma)
  weights = weights - (lr * grads /
                       (jnp.sqrt(avg_sq_grad) + eps)).astype(weights.dtype)
  return weights, avg_sq_grad
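A minimal sketch of the same RMSProp-style step without the class wrapper; all values below are illustrative, not the optimizer's real slots or defaults:

import jax.numpy as jnp

grads = jnp.array([0.1, -0.2, 0.3])
weights = jnp.array([1.0, 1.0, 1.0])
avg_sq_grad = jnp.ones_like(weights)      # slot: running average of grad**2
opt_params = {'learning_rate': 1e-3, 'gamma': 0.9, 'eps': 1e-8}

avg_sq_grad = (avg_sq_grad * opt_params['gamma'] +
               grads**2 * (1. - opt_params['gamma']))
weights = weights - (opt_params['learning_rate'] * grads /
                     (jnp.sqrt(avg_sq_grad) + opt_params['eps']))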
def _per_head_attention(queries, keys, values, mask, dropout, mode, rng):
  """Computes new per-head activations via scaled dot-product attention.

  This function is the core of the attention mechanism. Given per-head
  ``queries`` (Q), ``keys`` (K), ``values`` (V), and ``mask``, it:

    - computes the scaled dot product of each Q-K pair;
    - applies ``mask`` to screen out positions that come from padding tokens
      (indicated by 0 value);
    - [in ``'train'`` mode] applies dropout to Q-K dot products;
    - computes Q-K attention strengths using a per-query softmax of the Q-K
      dot products; and
    - for each query position, combines V vectors according to the Q-K
      attention strengths.

  Args:
    queries: Per-head activations representing attention queries.
    keys: Per-head activations representing attention keys.
    values: Per-head activations to be combined by computed attention
      strengths.
    mask: Mask that distinguishes positions with real content vs. padding.
    dropout: Probabilistic rate for attention dropout, which overrides
      (sets to zero) some attention strengths derived from query-key matching.
      As a result, on a given forward pass, some value vectors don't
      contribute to the output, analogous to how regular dropout can cause
      some node activations to be ignored. Applies only in ``'train'`` mode.
    mode: One of ``'train'``, ``'eval'``, or ``'predict'``.
    rng: Single-use random number generator (JAX PRNG key).

  Returns:
    Tuple of (activations, attn_strengths), where activations are new per-head
    activation vectors and attn_strengths is a matrix of per-head attention
    strengths.
  """
  if dropout >= 1.0:
    raise ValueError(f'Dropout rate ({dropout}) must be lower than 1.')

  d_feature = queries.shape[-1]

  dots = jnp.matmul(queries, jnp.swapaxes(keys, -1, -2)) / jnp.sqrt(d_feature)
  if mask is not None:
    dots = jnp.where(mask, dots, jnp.full_like(dots, -1e9))
  attn_strengths = (
      jnp.exp(dots - fastmath.logsumexp(dots, axis=-1, keepdims=True)))
  if dropout is not None and dropout > 0.0 and mode == 'train':
    keep = fastmath.random.bernoulli(rng, 1.0 - dropout, attn_strengths.shape)
    attn_strengths = jnp.where(keep,
                               attn_strengths / (1.0 - dropout),
                               jnp.zeros_like(attn_strengths))
  activations = jnp.matmul(attn_strengths, values).astype(jnp.float32)
  attn_strengths = attn_strengths.astype(jnp.float32)
  return activations, attn_strengths
def learning_rate(step):
  """Step to learning rate function."""
  ret = 1.0
  for name in factors:
    if name == 'constant':
      ret *= constant
    elif name == 'linear_warmup':
      ret *= jnp.minimum(1.0, step / warmup_steps)
    elif name == 'rsqrt_decay':
      ret /= jnp.sqrt(jnp.maximum(step, warmup_steps))
    elif name == 'rsqrt_normalized_decay':
      ret *= jnp.sqrt(warmup_steps)
      ret /= jnp.sqrt(jnp.maximum(step, warmup_steps))
    elif name == 'decay_every':
      ret *= (decay_factor**(step // steps_per_decay))
    elif name == 'cosine_decay':
      progress = jnp.maximum(0.0,
                             (step - warmup_steps) / float(steps_per_cycle))
      ret *= (0.5 * (1.0 + jnp.cos(jnp.pi * (progress % 1.0))))
    else:
      raise ValueError('Unknown factor %s.' % name)
  return float(ret)
def update(self, step, grads, weights, slots, opt_params):
  m, v = slots
  learning_rate = opt_params['learning_rate']
  weight_decay_rate = opt_params['weight_decay_rate']
  b1 = opt_params['b1']
  b2 = opt_params['b2']
  eps = opt_params['eps']
  m = (1 - b1) * grads + b1 * m  # First moment estimate.
  v = (1 - b2) * (grads ** 2) + b2 * v  # Second moment estimate.
  mhat = m / (1 - b1 ** (step + 1))  # Bias correction.
  vhat = v / (1 - b2 ** (step + 1))
  new_weights = ((1 - weight_decay_rate) * weights - (
      learning_rate * mhat / (jnp.sqrt(vhat) + eps))).astype(weights.dtype)
  return new_weights, (m, v)
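A short numeric sketch of the bias-correction step above: at step 0, with the moment slots freshly zero-initialized, `mhat` and `vhat` recover the first gradient's moments. The values are illustrative, not library defaults:

import jax.numpy as jnp

b1, b2, step = 0.9, 0.999, 0
grads = jnp.array([0.5])
m = (1 - b1) * grads               # = 0.05
v = (1 - b2) * grads**2            # = 0.00025
mhat = m / (1 - b1 ** (step + 1))  # = 0.5, the raw gradient
vhat = v / (1 - b2 ** (step + 1))  # = 0.25, the raw squared gradient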
def DotProductAttention(queries, keys, values, mask, dropout, mode, rng):
  """Computes new activations via masked attention-weighted sum of values.

  This function is the core of the attention mechanism. It:
    - computes per-head attention weights from per-head ``queries`` and
      ``keys``,
    - applies ``mask`` to screen out positions that come from padding tokens,
    - optionally applies dropout to attention weights, and
    - uses attention weights to combine per-head ``values`` vectors.

  Args:
    queries: Per-head activations representing attention queries.
    keys: Per-head activations representing attention keys.
    values: Per-head activations to be combined by computed attention weights.
    mask: Mask that distinguishes positions with real content vs. padding.
    dropout: Probabilistic rate for attention dropout, which overrides
      (sets to zero) some attention strengths derived from query-key matching.
      As a result, on a given forward pass, some value vectors don't
      contribute to the output, analogous to how regular dropout can cause
      some node activations to be ignored.
    mode: One of ``'train'``, ``'eval'``, or ``'predict'``.
    rng: Single-use random number generator (JAX PRNG key).

  Returns:
    Per-head activations resulting from masked per-head attention-weighted sum
    of per-head values.
  """
  d_feature = queries.shape[-1]
  dots = jnp.matmul(queries, jnp.swapaxes(keys, -1, -2)) / jnp.sqrt(d_feature)
  if mask is not None:
    dots = jnp.where(mask, dots, jnp.full_like(dots, -1e9))
  # Softmax.
  dots = jnp.exp(dots - fastmath.logsumexp(dots, axis=-1, keepdims=True))
  if dropout >= 1.0:
    raise ValueError('Dropout rates must be lower than 1.')
  if dropout is not None and dropout > 0.0 and mode == 'train':
    keep = fastmath.random.bernoulli(rng, 1.0 - dropout, dots.shape)
    dots = jnp.where(keep, dots / (1.0 - dropout), jnp.zeros_like(dots))
  out = jnp.matmul(dots, values)
  out = out.astype(jnp.float32)
  dots = dots.astype(jnp.float32)
  return out, dots
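A hedged call sketch for the batched, per-head version above; shapes follow the (batch, heads, length, depth) convention and are illustrative, and it assumes `trax.fastmath` provides the PRNG key via `fastmath.random.get_prng`:

import numpy as np
import jax.numpy as jnp
from trax import fastmath

b, h, L, d = 2, 4, 6, 8
q = jnp.asarray(np.random.randn(b, h, L, d).astype(np.float32))
k = jnp.asarray(np.random.randn(b, h, L, d).astype(np.float32))
v = jnp.asarray(np.random.randn(b, h, L, d).astype(np.float32))
mask = jnp.ones((b, 1, L, L), dtype=jnp.bool_)   # no padding in this toy case
rng = fastmath.random.get_prng(0)

out, attn_weights = DotProductAttention(q, k, v, mask,
                                        dropout=0.1, mode='train', rng=rng)
# out: (b, h, L, d); attn_weights: (b, h, L, L)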
def _calc_attn_scores(q, k):
  ac = jnp.einsum('bnid,bnjd->bnij', q + context_bias, k)
  bd = jnp.einsum('bnid,jnd->bnij', q + location_bias, pos_emb)

  if mode != 'predict':
    bd = _fast_matrix_shift(bd)

  dots = (ac + bd) / jnp.sqrt(d_feature)
  dots = jnp.where(mask, dots, jnp.full_like(dots, -1e9))

  # Softmax.
  dots = jnp.exp(dots - fastmath.logsumexp(dots, axis=-1, keepdims=True))
  if dropout >= 1.0:
    raise ValueError('Dropout rates must be lower than 1.')
  if dropout is not None and dropout > 0.0 and mode == 'train':
    keep = fastmath.random.bernoulli(rng, 1.0 - dropout, dots.shape)
    dots = jnp.where(keep, dots / (1.0 - dropout), jnp.zeros_like(dots))

  return dots
def DotProductAttention(query, key, value, mask):
  assert query.shape[-1] == key.shape[-1] == value.shape[-1]

  depth = query.shape[-1]
  # Scaled query-key dot products (part of the dot-product attention formula).
  dots = jnp.matmul(query, jnp.swapaxes(key, -1, -2)) / jnp.sqrt(depth)

  # Apply mask.
  if mask is not None:
    dots = jnp.where(mask, dots, jnp.full_like(dots, -1e9))

  # Rest of the dot-product attention formula.
  logsumexp = trax.fastmath.logsumexp(dots, axis=-1, keepdims=True)
  dots = jnp.exp(dots - logsumexp)

  attention = jnp.matmul(dots, value)

  return attention
def DotProductAttention(query, key, value, mask): """Dot product self-attention. Args: query (jax.interpreters.xla.DeviceArray): array of query representations with shape (L_q by d) key (jax.interpreters.xla.DeviceArray): array of key representations with shape (L_k by d) value (jax.interpreters.xla.DeviceArray): array of value representations with shape (L_k by d) where L_v = L_k mask (jax.interpreters.xla.DeviceArray): attention-mask, gates attention with shape (L_q by L_k) Returns: jax.interpreters.xla.DeviceArray: Self-attention array for q, k, v arrays. (L_q by L_k) """ assert query.shape[-1] == key.shape[-1] == value.shape[ -1], "Embedding dimensions of q, k, v aren't all the same" # scaling down (Q. K) dot product with square root of depth depth = query.shape[-1] # Calculate scaled query key dot product according to formula above dots = jnp.matmul(query, jnp.swapaxes(key, -1, -2)) / jnp.sqrt(depth) # Apply the mask if mask is not None: dots = jnp.where(mask, dots, jnp.full_like(dots, -1e9)) # Softmax formula implementation # Use trax.fastmath.logsumexp of dots to avoid underflow by division by large numbers logsumexp = trax.fastmath.logsumexp(dots, axis=-1, keepdims=True) # Note: softmax = e^(dots - logsumexp(dots)) = E^dots / sumexp(dots) dots = jnp.exp(dots - logsumexp) # Multiply dots by value to get self-attention # Use jnp.matmul() attention = jnp.matmul(dots, value) return attention