def update(self, step, grads, params, slots, opt_params):
  updates = []
  learning_rate = opt_params["learning_rate"]
  beta1 = opt_params["beta1"]
  decay_rate = opt_params["decay_rate"]
  clipping_threshold = opt_params["clipping_threshold"]
  weight_decay_rate = opt_params["weight_decay_rate"]
  epsilon1 = opt_params["epsilon1"]
  epsilon2 = opt_params["epsilon2"]
  decay_rate = self._decay_rate_pow(step, exponent=decay_rate)
  update_scale = learning_rate
  if self._multiply_by_parameter_scale:
    update_scale *= np.maximum(np.sqrt(np.mean(params * params)), epsilon2)
  mixing_rate = 1.0 - decay_rate

  grads_sqr = grads * grads + epsilon1
  if self._factored and len(params.shape) >= 2:
    v_row = slots.pop(0)
    v_col = slots.pop(0)
    new_v_row = decay_rate * v_row + mixing_rate * np.mean(grads_sqr, axis=-1)
    new_v_col = decay_rate * v_col + mixing_rate * np.mean(grads_sqr, axis=-2)
    updates.extend([new_v_row, new_v_col])
    row_col_mean = np.mean(new_v_row, axis=-1, keepdims=True)
    row_factor = (new_v_row / row_col_mean)**-0.5
    col_factor = (new_v_col)**-0.5
    y = (grads * np.expand_dims(row_factor, axis=-1) *
         np.expand_dims(col_factor, axis=-2))
  else:
    v = slots.pop(0)
    new_v = decay_rate * v + mixing_rate * grads_sqr
    updates.append(new_v)
    y = grads * (new_v)**-0.5

  if self._do_clipping:
    clipping_denom = (np.maximum(
        1.0, np.sqrt(np.mean(y * y)) / clipping_threshold))
    y /= clipping_denom

  subtrahend = update_scale * y
  if self._do_momentum:
    m = slots.pop(0)
    new_m = beta1 * m + (1.0 - beta1) * subtrahend
    subtrahend = new_m
    updates.append(new_m)

  new_params = (1 - weight_decay_rate) * params - subtrahend
  # TODO(lukaszkaiser): why is the astype needed here? Check and correct.
  return new_params.astype(params.dtype), updates
def update(self, step, grads, params, avg_sq_grad, opt_params):
  del step
  (learning_rate, gamma, eps) = opt_params
  avg_sq_grad = avg_sq_grad * gamma + grads**2 * (1. - gamma)
  params = params - (learning_rate * grads /
                     (np.sqrt(avg_sq_grad) + eps)).astype(params.dtype)
  return params, avg_sq_grad
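# Illustrative sketch (not from the source): one RMSProp-style step with the
# same arithmetic as `update` above, using plain NumPy in place of the
# backend `np` and plain Python scalars (so the dtype cast is omitted).
# All numbers are made-up example values.
import numpy as onp

grads, params, avg_sq_grad = 2.0, 1.0, 0.0   # scalar stand-ins for tensors
learning_rate, gamma, eps = 0.01, 0.9, 1e-8
avg_sq_grad = avg_sq_grad * gamma + grads**2 * (1. - gamma)      # -> 0.4
params = params - learning_rate * grads / (onp.sqrt(avg_sq_grad) + eps)
assert onp.isclose(params, 1.0 - 0.02 / onp.sqrt(0.4))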
def DotProductAttention(query, key, value, mask, dropout, mode, rng):
  """Core dot product self-attention.

  Args:
    query: array of representations
    key: array of representations
    value: array of representations
    mask: attention-mask, gates attention
    dropout: float: dropout rate
    mode: 'eval' or 'train': whether to use dropout
    rng: JAX PRNGKey: subkey for disposable use

  Returns:
    Self attention for q, k, v arrays.
  """
  depth = np.shape(query)[-1]
  dots = np.matmul(query, np.swapaxes(key, -1, -2)) / np.sqrt(depth)
  if mask is not None:
    # TODO(kitaev): workaround for https://github.com/google/jax/issues/850
    # We must ensure that both mask and the -1e9 constant have a data
    # dependency on the input. Broadcasted copies of these use a lot of
    # memory, so they should be computed at runtime (rather than being global
    # constants).
    if backend.get_name() == 'jax':
      mask = jax.lax.tie_in(dots, mask)
    dots = np.where(mask, dots, np.full_like(dots, -1e9))
  # Softmax.
  dots = np.exp(dots - backend.logsumexp(dots, axis=-1, keepdims=True))
  if dropout is not None and dropout >= 1.0:  # Guard None before comparing.
    raise ValueError('Dropout rates must be lower than 1.')
  if dropout is not None and dropout > 0.0 and mode == 'train':
    keep = backend.random.bernoulli(rng, 1.0 - dropout, dots.shape)
    dots = np.where(keep, dots / (1.0 - dropout), np.zeros_like(dots))
  out = np.matmul(dots, value)
  return out
def dot_product_attention(query, key, value, mask, dropout, mode, rng):
  """Core dot product self-attention.

  Args:
    query: array of representations
    key: array of representations
    value: array of representations
    mask: attention-mask, gates attention
    dropout: float: keep probability used for dropout
    mode: 'eval' or 'train': whether to use dropout
    rng: JAX PRNGKey: subkey for disposable use

  Returns:
    Self attention for q, k, v arrays.
  """
  depth = np.shape(query)[-1]
  dots = np.matmul(query, np.swapaxes(key, -1, -2)) / np.sqrt(depth)
  if mask is not None:
    dots = np.where(mask, dots, -1e9)
  dots = stax.softmax(dots, axis=-1)
  if dropout is not None and mode == 'train':
    keep = random.bernoulli(rng, dropout, dots.shape)
    dots = np.where(keep, dots / dropout, 0)
  out = np.matmul(dots, value)
  return out
def BatchNorm(x, params, axis=(0, 1, 2), epsilon=1e-5,
              center=True, scale=True, **unused_kwargs):
  """Layer construction function for a batch normalization layer."""
  mean = np.mean(x, axis, keepdims=True)
  # Fast but less numerically-stable variance calculation than np.var.
  m1 = np.mean(x**2, axis, keepdims=True)
  var = m1 - mean**2
  z = (x - mean) / np.sqrt(var + epsilon)

  # Expand the parameters to have the right axes.
  beta, gamma = params
  # TODO(phawkins): np.expand_dims should accept an axis tuple.
  # (https://github.com/numpy/numpy/issues/12290)
  ed = tuple(None if i in axis else slice(None) for i in range(np.ndim(x)))
  beta = beta[ed]
  gamma = gamma[ed]

  # Return the z rescaled by the parameters if requested.
  if center and scale:
    return gamma * z + beta
  if center:
    return z + beta
  if scale:
    return gamma * z
  return z
def apply_fun(params, inputs, **kwargs):
  del kwargs
  (scale, bias) = params
  mean = np.mean(inputs, axis=-1, keepdims=True)
  variance = np.mean((inputs - mean)**2, axis=-1, keepdims=True)
  norm_inputs = (inputs - mean) / np.sqrt(variance + epsilon)
  return norm_inputs * scale + bias
def binned_attn(sq, sk, sv):
  """Performs attention on sorted queries/keys/values."""
  # Split off a "bin" axis so that attention only occurs within chunks.
  bq_t = chunk_rank3(sq_t)
  bkv_t = chunk_rank3(skv_t)
  bq = chunk_rank4(sq)
  bk = chunk_rank4(sk)
  bv = chunk_rank4(sv)

  dots = np.matmul(bq, np.swapaxes(bk, -1, -2)) / np.sqrt(bq.shape[-1])

  # Causal masking
  mask = jax.lax.convert_element_type(
      jax.lax.lt(bq_t[:, :, :, :, None], bkv_t[:, :, :, None, :]),
      np.float32)
  dots = dots - 1e9 * mask

  # Softmax.
  dots = np.exp(dots - dots.max(axis=-1, keepdims=True))
  dots = dots / dots.sum(axis=-1, keepdims=True)
  bo = np.matmul(dots, bv)
  so = unchunk_rank4(bo)
  return so
def DotProductAttention(query, key, value, mask, dropout, mode, rng):
  """Core dot product self-attention.

  Args:
    query: array of representations
    key: array of representations
    value: array of representations
    mask: attention-mask, gates attention
    dropout: float: dropout rate
    mode: 'eval' or 'train': whether to use dropout
    rng: JAX PRNGKey: subkey for disposable use

  Returns:
    Self attention for q, k, v arrays.
  """
  depth = np.shape(query)[-1]
  dots = np.matmul(query, np.swapaxes(key, -1, -2)) / np.sqrt(depth)
  if mask is not None:
    dots = np.where(mask, dots, -1e9)
  # Softmax.
  dots = np.exp(dots - backend.logsumexp(dots, axis=-1, keepdims=True))
  if dropout is not None and dropout >= 1.0:  # Guard None before comparing.
    raise ValueError('Dropout rates must be lower than 1.')
  if dropout is not None and dropout > 0.0 and mode == 'train':
    keep = backend.random.bernoulli(rng, 1.0 - dropout, dots.shape)
    dots = np.where(keep, dots / (1.0 - dropout), 0)
  out = np.matmul(dots, value)
  return out
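# Illustrative sketch (not from the source): the shapes and masking
# convention used by the attention function above, re-derived on toy tensors
# with plain NumPy in place of the backend `np` (no dropout, no jax).  All
# sizes below are made-up example values.
import numpy as onp

batch, heads, seqlen, depth = 1, 2, 4, 8
q = onp.ones((batch, heads, seqlen, depth), dtype=onp.float32)
k = onp.ones((batch, heads, seqlen, depth), dtype=onp.float32)
v = onp.ones((batch, heads, seqlen, depth), dtype=onp.float32)
# Lower-triangular causal mask: position i may attend to positions <= i.
mask = onp.tril(onp.ones((seqlen, seqlen), dtype=bool))[None, None, :, :]

dots = onp.matmul(q, onp.swapaxes(k, -1, -2)) / onp.sqrt(depth)
dots = onp.where(mask, dots, -1e9)
dots = onp.exp(dots - dots.max(axis=-1, keepdims=True))
dots = dots / dots.sum(axis=-1, keepdims=True)
out = onp.matmul(dots, v)
assert out.shape == (batch, heads, seqlen, depth)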
def BatchNorm(x, params, axis=(0, 1, 2), epsilon=1e-5,
              center=True, scale=True, **unused_kwargs):
  """Layer construction function for a batch normalization layer."""
  mean = np.mean(x, axis, keepdims=True)
  # Fast but less numerically-stable variance calculation than np.var.
  m1 = np.mean(x**2, axis, keepdims=True)
  var = m1 - mean**2
  # x mustn't be onp.ndarray here; otherwise `x - mean` will call
  # mean.__rsub__ with each element of x, resulting in an onp.ndarray with
  # dtype `object`.
  z = (x - mean) / np.sqrt(var + epsilon).astype(x.dtype)

  # Expand the parameters to have the right axes.
  beta, gamma = params
  # TODO(phawkins): np.expand_dims should accept an axis tuple.
  # (https://github.com/numpy/numpy/issues/12290)
  ed = tuple(None if i in axis else slice(None) for i in range(np.ndim(x)))
  beta = beta[ed]
  gamma = gamma[ed]

  # Return the z rescaled by the parameters if requested.
  if center and scale:
    ret = gamma * z + beta
  elif center:
    ret = z + beta
  elif scale:
    ret = gamma * z
  else:
    ret = z
  assert ret.dtype == x.dtype, ('The dtype of the output (%s) of batch norm '
                                'is not the same as the input (%s). Batch '
                                'norm should not change the dtype.' %
                                (ret.dtype, x.dtype))
  return ret
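# Illustrative check (not from the source): the "fast" variance above uses
# the identity var = E[x**2] - E[x]**2, which matches np.var exactly in exact
# arithmetic but is less numerically stable.  Plain NumPy stands in for the
# backend `np`:
import numpy as onp

x = onp.array([1.0, 2.0, 3.0, 4.0])
m1 = onp.mean(x**2)              # 7.5
mean = onp.mean(x)               # 2.5
fast_var = m1 - mean**2          # 7.5 - 6.25 = 1.25
assert onp.isclose(fast_var, onp.var(x))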
def forward_slice(query_slice, q_loop_idx, key, value):  # pylint: disable=invalid-name
  """Forward pass for a subset of the query vectors."""
  dots = np.matmul(query_slice, np.swapaxes(key, -1, -2)) / np.sqrt(depth)

  # Causal masking
  mask = make_mask(dots.shape[-2], dots.shape[-1], q_loop_idx)
  dots = dots - 1e9 * mask

  # Softmax.
  dots = np.exp(dots - backend.logsumexp(dots, axis=-1, keepdims=True))

  if self.dropout is not None and self.dropout > 0.0:
    # Dropout is broadcast across the batch+head dimension
    dropout_shape = (1, dots.shape[-2], dots.shape[-1])
    slice_rng = jax.random.fold_in(rng, q_loop_idx)
    keep_prob = jax.lax.tie_in(dots, 1.0 - self.dropout)
    keep = backend.random.bernoulli(slice_rng, keep_prob, dropout_shape)
    multiplier = keep.astype(dots.dtype) / jax.lax.tie_in(keep, keep_prob)
    dots = dots * multiplier

  out_slice = np.matmul(dots, value)
  return out_slice
def update(self, step, grads, params, slots, opt_params):
  updates = []
  (learning_rate, beta1, decay_rate, clipping_threshold,
   weight_decay_rate, epsilon1, epsilon2) = opt_params
  decay_rate = self._decay_rate_pow(step, exponent=decay_rate)
  update_scale = learning_rate
  if self._multiply_by_parameter_scale:
    update_scale *= np.maximum(np.sqrt(np.mean(params * params)), epsilon2)
  mixing_rate = 1.0 - decay_rate

  grads_sqr = grads * grads + epsilon1
  if self._factored and len(params.shape) >= 2:
    v_row = slots.pop(0)
    v_col = slots.pop(0)
    new_v_row = decay_rate * v_row + mixing_rate * np.mean(grads_sqr, axis=-1)
    new_v_col = decay_rate * v_col + mixing_rate * np.mean(grads_sqr, axis=-2)
    updates.extend([new_v_row, new_v_col])
    row_col_mean = np.mean(new_v_row, axis=-1, keepdims=True)
    row_factor = (new_v_row / row_col_mean)**-0.5
    col_factor = (new_v_col)**-0.5
    y = (grads * np.expand_dims(row_factor, axis=-1) *
         np.expand_dims(col_factor, axis=-2))
  else:
    v = slots.pop(0)
    new_v = decay_rate * v + mixing_rate * grads_sqr
    updates.append(new_v)
    y = grads * (new_v)**-0.5

  if self._do_clipping:
    clipping_denom = (np.maximum(
        1.0, np.sqrt(np.mean(y * y)) / clipping_threshold))
    y /= clipping_denom

  subtrahend = update_scale * y
  if self._do_momentum:
    m = slots.pop(0)
    new_m = beta1 * m + (1.0 - beta1) * subtrahend
    subtrahend = new_m
    updates.append(new_m)

  new_params = (1 - weight_decay_rate) * params - subtrahend
  return new_params, updates
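# Illustrative sketch (not from the source): the factored branch above keeps
# only per-row and per-column second-moment statistics and implicitly
# rebuilds a per-element preconditioner as
#   v_hat[i, j] ~= v_row[i] * v_col[j] / mean(v_row),
# which is exactly what multiplying by row_factor and col_factor computes.
# For simplicity this sketch takes a single step with zero-initialized
# statistics and ignores the decay mixing and epsilon terms; plain NumPy
# stands in for the backend `np`.
import numpy as onp

grads = onp.array([[1.0, 2.0], [3.0, 4.0]])
grads_sqr = grads * grads
v_row = onp.mean(grads_sqr, axis=-1)                  # shape (rows,)
v_col = onp.mean(grads_sqr, axis=-2)                  # shape (cols,)

# Rank-1 reconstruction of the second-moment matrix, then precondition.
v_hat = onp.outer(v_row, v_col) / onp.mean(v_row)
y_from_v_hat = grads * v_hat**-0.5

# Same result written with the row/col factors used in `update` above.
row_factor = (v_row / onp.mean(v_row))**-0.5
col_factor = v_col**-0.5
y = grads * row_factor[:, None] * col_factor[None, :]
assert onp.allclose(y, y_from_v_hat)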
def _update_diagonal(self, step, g, x, m, v):
  v[0] += g * g
  preconditioner = np.where(v[0] > 0, 1.0 / np.sqrt(v[0]),
                            np.zeros_like(v[0]))
  preconditioned_g = preconditioner * g
  m = (1 - self._momentum) * preconditioned_g + self._momentum * m
  x = x - self.step_size(step) * m
  return x, (m, v)
def _update_diagonal(self, grads, params, m, v, opt_params):
  (learning_rate, momentum) = opt_params
  # Accumulate squared gradients and precondition the gradient by the
  # inverse square root of the accumulator (guarding against division by 0).
  v[0] += grads * grads
  preconditioner = np.where(v[0] > 0, 1.0 / np.sqrt(v[0]),
                            np.zeros_like(v[0]))
  preconditioned_grads = preconditioner * grads
  m = (1 - momentum) * preconditioned_grads + momentum * m
  params = params - (learning_rate * m).astype(params.dtype)
  return params, (m, v)
def update(self, i, g, x, state):
  m, v = state
  b1, b2, eps = self._b1, self._b2, self._eps
  m = (1 - b1) * g + b1 * m  # First moment estimate.
  v = (1 - b2) * (g**2) + b2 * v  # Second moment estimate.
  mhat = m / (1 - b1**(i + 1))  # Bias correction.
  vhat = v / (1 - b2**(i + 1))
  x = x - self.step_size(i) * mhat / (np.sqrt(vhat) + eps)
  return x, (m, v)
def update(self, step, grads, params, avg_sq_grad, opt_params):
  del step
  learning_rate = opt_params["learning_rate"]
  gamma = opt_params["gamma"]
  eps = opt_params["eps"]
  avg_sq_grad = avg_sq_grad * gamma + grads**2 * (1. - gamma)
  params = params - (learning_rate * grads /
                     (np.sqrt(avg_sq_grad) + eps)).astype(params.dtype)
  return params, avg_sq_grad
def update(self, step, grads, params, slots, opt_params):
  m, v = slots
  learning_rate, weight_decay_rate, b1, b2, eps = opt_params
  m = (1 - b1) * grads + b1 * m  # First moment estimate.
  v = (1 - b2) * (grads**2) + b2 * v  # Second moment estimate.
  mhat = m / (1 - b1**(step + 1))  # Bias correction.
  vhat = v / (1 - b2**(step + 1))
  params = (1 - weight_decay_rate) * params - (
      learning_rate * mhat / (np.sqrt(vhat) + eps)).astype(params.dtype)
  return params, (m, v)
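# Illustrative sketch (not from the source): one step of the Adam-with-
# weight-decay arithmetic in `update` above, for a scalar parameter, with
# plain NumPy in place of the backend `np` and the dtype cast omitted.  All
# constants are made-up example values.
import numpy as onp

step = 0
grads, params = 0.5, 1.0
m, v = 0.0, 0.0                                  # freshly initialized slots
learning_rate, weight_decay_rate = 1e-3, 1e-5
b1, b2, eps = 0.9, 0.999, 1e-5

m = (1 - b1) * grads + b1 * m                    # 0.05
v = (1 - b2) * grads**2 + b2 * v                 # 0.00025
mhat = m / (1 - b1**(step + 1))                  # bias-corrects back to 0.5
vhat = v / (1 - b2**(step + 1))                  # bias-corrects back to 0.25
params = (1 - weight_decay_rate) * params - (
    learning_rate * mhat / (onp.sqrt(vhat) + eps))
# At step 0 the update is roughly learning_rate * sign(grads), plus the
# small decoupled weight-decay term.
assert onp.isclose(params, (1 - 1e-5) - 1e-3 * 0.5 / (0.5 + 1e-5))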
def update(self, i, g, x, state):
  updates = []
  decay_rate = self._decay_rate(i)
  update_scale = self._step_size(i)
  if self._multiply_by_parameter_scale:
    update_scale *= np.maximum(np.sqrt(np.mean(x * x)), self._epsilon2)
  mixing_rate = 1.0 - decay_rate

  g_sqr = g * g + self._epsilon1
  if self._factored and len(x.shape) >= 2:
    v_row = state.pop(0)
    v_col = state.pop(0)
    new_v_row = decay_rate * v_row + mixing_rate * np.mean(g_sqr, axis=-1)
    new_v_col = decay_rate * v_col + mixing_rate * np.mean(g_sqr, axis=-2)
    updates.extend([new_v_row, new_v_col])
    row_col_mean = np.mean(new_v_row, axis=-1, keepdims=True)
    row_factor = (new_v_row / row_col_mean)**-0.5
    col_factor = (new_v_col)**-0.5
    y = (g * np.expand_dims(row_factor, axis=-1) *
         np.expand_dims(col_factor, axis=-2))
  else:
    v = state.pop(0)
    new_v = decay_rate * v + mixing_rate * g_sqr
    updates.append(new_v)
    y = g * (new_v)**-0.5

  if self._clipping_threshold is not None:
    clipping_denom = (
        np.maximum(1.0, np.sqrt(np.mean(y * y)) / self._clipping_threshold))
    y /= clipping_denom

  subtrahend = update_scale * y
  if self._beta1:
    m = state.pop(0)
    new_m = self._beta1 * m + (1.0 - self._beta1) * subtrahend
    subtrahend = new_m
    updates.append(new_m)

  new_x = x - subtrahend
  return new_x, updates
def call(self, x, params, state, **unused_kwargs):
  """Layer construction function for a batch normalization layer."""
  running_mean, running_var, num_batches = state
  if self._mode == 'train':
    mean = np.mean(x, self._axis, keepdims=True)
    # Fast but less numerically-stable variance calculation than np.var.
    m1 = np.mean(x**2, self._axis, keepdims=True)
    var = m1 - mean**2
    num_batches = num_batches + 1
    if self._momentum is None:
      # A simple average over all batches seen so far.
      exponential_average_factor = 1.0 / num_batches
    else:
      exponential_average_factor = self._momentum
    def average(factor, new, old):
      return (factor * new + (1 - factor) * old).astype(old.dtype)
    running_mean = average(exponential_average_factor, mean, running_mean)
    running_var = average(exponential_average_factor, var, running_var)
    state = (running_mean, running_var, num_batches)
  else:
    mean = running_mean
    var = running_var

  z = (x - mean.astype(x.dtype)) / np.sqrt(var + self._epsilon).astype(
      x.dtype)

  # Expand the parameters to have the right axes.
  beta, gamma = params
  # TODO(phawkins): np.expand_dims should accept an axis tuple.
  # (https://github.com/numpy/numpy/issues/12290)
  ed = tuple(None if i in self._axis else slice(None)
             for i in range(np.ndim(x)))
  beta = beta[ed]
  gamma = gamma[ed]

  # Return the z rescaled by the parameters if requested.
  if self._center and self._scale:
    output = gamma * z + beta
  elif self._center:
    output = z + beta
  elif self._scale:
    output = gamma * z
  else:
    output = z
  assert output.dtype == x.dtype, ('The dtype of the output (%s) of batch '
                                   'norm is not the same as the input (%s). '
                                   'Batch norm should not change the dtype' %
                                   (output.dtype, x.dtype))
  return output, state
def learning_rate(step):  # pylint: disable=invalid-name
  """Step to learning rate function."""
  ret = 1.0
  for name in factors:
    if name == "constant":
      ret *= constant
    elif name == "linear_warmup":
      ret *= np.minimum(1.0, step / warmup_steps)
    elif name == "rsqrt_decay":
      ret /= np.sqrt(np.maximum(step, warmup_steps))
    else:
      raise ValueError("Unknown factor %s." % name)
  return ret
def apply_fun(params, x, **kwargs):
  beta, gamma = params
  # TODO(phawkins): np.expand_dims should accept an axis tuple.
  # (https://github.com/numpy/numpy/issues/12290)
  ed = tuple(None if i in axis else slice(None) for i in range(np.ndim(x)))
  beta = beta[ed]
  gamma = gamma[ed]
  mean, var = np.mean(x, axis, keepdims=True), fastvar(x, axis, keepdims=True)
  z = (x - mean) / np.sqrt(var + epsilon)
  if center and scale:
    return gamma * z + beta
  if center:
    return z + beta
  if scale:
    return gamma * z
  return z
def test_batch_norm(self):
  input_shape = (2, 3, 4)
  input_dtype = np.float32
  eps = 1e-5
  rng = backend.random.get_prng(0)
  inp1 = np.reshape(np.arange(np.prod(input_shape), dtype=input_dtype),
                    input_shape)
  m1 = 11.5  # Mean of this input.
  v1 = 47.9167  # Variance of this input.
  layer = normalization.BatchNorm(axis=(0, 1, 2))
  params, state = layer.initialize(input_shape, input_dtype, rng)
  onp.testing.assert_allclose(state[0], 0)
  onp.testing.assert_allclose(state[1], 0)
  self.assertEqual(state[2], 0)

  out, state = layer(inp1, params, state)
  onp.testing.assert_allclose(state[0], m1)
  onp.testing.assert_allclose(state[1], v1, rtol=1e-6)
  self.assertEqual(state[2], 1)
  onp.testing.assert_allclose(out, (inp1 - m1) / np.sqrt(v1 + eps),
                              rtol=1e-6)

  inp2 = inp1 * 2 + 3
  m2 = m1 * 2 + 3
  v2 = v1 * 4
  m12 = (m1 + m2) / 2
  v12 = (v1 + v2) / 2
  out, state = layer(inp2, params, state)
  onp.testing.assert_allclose(state[0], m12)
  onp.testing.assert_allclose(state[1], v12, rtol=1e-6)
  self.assertEqual(state[2], 2)
  onp.testing.assert_allclose(out, (inp2 - m2) / np.sqrt(v2 + eps),
                              rtol=1e-6)

  layer = normalization.BatchNorm(axis=(0, 1, 2), mode="eval")
  inp3 = inp1 * 5 + 7
  out, state_unchanged = layer(inp3, params, state)
  for i in range(3):
    onp.testing.assert_allclose(state_unchanged[i], state[i])
  onp.testing.assert_allclose(out, (inp3 - m12) / np.sqrt(v12 + eps),
                              rtol=1e-6)
def binned_attn(sqk, sv):  # pylint: disable=invalid-name
  """Performs attention on sorted queries/keys/values."""
  # Split off a "bin" axis so that attention only occurs within chunks.
  bq_t = bkv_t = chunk_scalars(sjoint_t)
  bqk = chunk_vectors(sqk)
  bv = chunk_vectors(sv)

  # Hashing operates on unit-length vectors. Unnormalized query vectors are
  # fine because they effectively provide a learnable temperature for the
  # attention softmax, but normalizing keys is needed so that similarity for
  # the purposes of attention correctly corresponds to hash locality.
  bq = bqk
  bk = self.make_unit_length(bqk)

  # Allow each chunk to attend within itself, and also one chunk back. Chunk
  # boundaries might occur in the middle of a sequence of items from the
  # same bin, so this increases the chances of attending to relevant items.
  # TODO(kitaev): benchmark whether XLA pad operation is noticeably faster.
  bk_extra = np.concatenate([bk[:, -1:, :, :], bk[:, :-1, :, :]], axis=1)
  bk = np.concatenate([bk, bk_extra], axis=2)
  bv_extra = np.concatenate([bv[:, -1:, :, :], bv[:, :-1, :, :]], axis=1)
  bv = np.concatenate([bv, bv_extra], axis=2)
  bkv_t_extra = np.concatenate([bkv_t[:, -1:, :], bkv_t[:, :-1, :]], axis=1)
  bkv_t = np.concatenate([bkv_t, bkv_t_extra], axis=2)

  # Dot-product attention.
  dots = np.matmul(bq, np.swapaxes(bk, -1, -2)) / np.sqrt(bq.shape[-1])

  # Causal masking
  mask = jax.lax.convert_element_type(
      jax.lax.lt(bq_t[:, :, :, None], bkv_t[:, :, None, :]),
      np.float32)
  dots = dots - 1e9 * mask

  # Mask out attention to self except when no other targets are available.
  self_mask = jax.lax.broadcasted_eye(dots.dtype, dots.shape, (2, 3))
  self_mask = jax.lax.tie_in(dots, self_mask)
  dots = dots - 32 * self_mask

  # Softmax.
  dots = np.exp(dots - backend.logsumexp(dots, axis=-1, keepdims=True))
  bo = np.matmul(dots, bv)
  so = unchunk_vectors(bo)
  return so
def forward_slice(query_slice, q_loop_idx, key, value):
  """Forward pass for a subset of the query vectors."""
  dots = np.matmul(query_slice, np.swapaxes(key, -1, -2)) / np.sqrt(depth)

  # Causal masking
  mask = make_mask(dots.shape[-2], dots.shape[-1], q_loop_idx)
  dots = dots - 1e9 * mask

  # Softmax.
  dots = np.exp(dots - dots.max(axis=-1, keepdims=True))
  dots = dots / dots.sum(axis=-1, keepdims=True)

  out_slice = np.matmul(dots, value)
  return out_slice
def learning_rate(step):  # pylint: disable=invalid-name
  """Step to learning rate function."""
  ret = 1.0
  for name in factors:
    if name == "constant":
      ret *= constant
    elif name == "linear_warmup":
      ret *= np.minimum(1.0, step / warmup_steps)
    elif name == "rsqrt_decay":
      ret /= np.sqrt(np.maximum(step, warmup_steps))
    elif name == "decay_every":
      ret *= (decay_factor**(step // steps_per_decay))
    else:
      raise ValueError("Unknown factor %s." % name)
  ret = np.asarray(ret, dtype=np.float32)
  return {"learning_rate": ret}
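# Illustrative sketch (not from the source): how the factors compose.  With
# factors ("constant", "linear_warmup", "rsqrt_decay"), constant = 1.0 and
# warmup_steps = 8000 (made-up example values), the rate rises linearly to
# 1/sqrt(warmup_steps) at step 8000 and then decays as 1/sqrt(step).  Plain
# NumPy stands in for the backend `np`; the closure variables of
# `learning_rate` become explicit arguments of this hypothetical helper.
import numpy as onp

def example_learning_rate(step, constant=1.0, warmup_steps=8000):
  ret = 1.0
  ret *= constant                                   # "constant" factor
  ret *= onp.minimum(1.0, step / warmup_steps)      # "linear_warmup"
  ret /= onp.sqrt(onp.maximum(step, warmup_steps))  # "rsqrt_decay"
  return ret

assert onp.isclose(example_learning_rate(4000), 0.5 / onp.sqrt(8000))
assert onp.isclose(example_learning_rate(8000), 1.0 / onp.sqrt(8000))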
def test_batch_norm(self):
  input_shape = (2, 3, 4)
  input_dtype = np.float32
  eps = 1e-5
  rng = backend.random.get_prng(0)
  inp1 = np.reshape(np.arange(np.prod(input_shape), dtype=input_dtype),
                    input_shape)
  m1 = 11.5  # Mean of this input.
  v1 = 47.9167  # Variance of this input.
  layer = normalization.BatchNorm(axis=(0, 1, 2))
  params, state = layer.initialize(input_shape, input_dtype, rng)
  onp.testing.assert_allclose(state[0], 0)
  onp.testing.assert_allclose(state[1], 1)
  self.assertEqual(state[2], 0)

  out, state = layer(inp1, params, state)
  onp.testing.assert_allclose(state[0], m1 * 0.001)
  onp.testing.assert_allclose(state[1], 0.999 + v1 * 0.001, rtol=1e-6)
  self.assertEqual(state[2], 1)
  onp.testing.assert_allclose(out, (inp1 - m1) / np.sqrt(v1 + eps),
                              rtol=1e-6)
def _update_sketched(self, grads, params, m, v, opt_params):
  """Update for higher-rank parameters."""
  (learning_rate, momentum) = opt_params
  shape = params.shape
  rank = len(shape)
  reshaped_accumulators = [np.reshape(v[i], self._expanded_shape(shape, i))
                           for i in range(rank)]
  current_accumulator = self._minimum(reshaped_accumulators)
  current_accumulator += grads * grads
  accumulator_inv_sqrt = np.where(current_accumulator > 0.0,
                                  1.0 / np.sqrt(current_accumulator),
                                  np.zeros_like(current_accumulator))
  preconditioned_gradient = grads * accumulator_inv_sqrt
  m = (1.0 - momentum) * preconditioned_gradient + momentum * m
  params = params - (learning_rate * m).astype(params.dtype)
  for i in range(len(v)):
    axes = list(range(int(i))) + list(range(int(i) + 1, rank))
    dim_accumulator = np.amax(current_accumulator, axis=axes)
    v[i] = dim_accumulator
  return params, (m, v)
def forward_slice(query_slice, q_loop_idx, key, value):  # pylint: disable=invalid-name
  """Forward pass for a subset of the query vectors."""
  if self._share_qk:
    key = self.make_unit_length(key)

  dots = np.matmul(query_slice, np.swapaxes(key, -1, -2)) / np.sqrt(depth)

  # Causal masking
  mask = make_mask(dots.shape[-2], dots.shape[-1], q_loop_idx)
  dots = dots - 1e9 * mask

  # Mask out attention to self except when no other targets are available.
  if self._share_qk:
    self_mask = make_self_mask(dots.shape[-2], dots.shape[-1], q_loop_idx)
    dots = dots - 1e5 * self_mask

  # Softmax.
  dots = np.exp(dots - backend.logsumexp(dots, axis=-1, keepdims=True))

  if self.dropout is not None and self.dropout > 0.0:
    # Dropout is broadcast across the batch+head dimension
    dropout_shape = (1, dots.shape[-2], dots.shape[-1])
    slice_rng = jax.random.fold_in(rng, q_loop_idx)
    keep_prob = jax.lax.tie_in(dots, 1.0 - self.dropout)
    keep = backend.random.bernoulli(slice_rng, keep_prob, dropout_shape)
    multiplier = keep.astype(dots.dtype) / jax.lax.tie_in(keep, keep_prob)
    dots = dots * multiplier

  if self._hard_k > 0:
    top_k = np.sort(dots)[..., -self._hard_k]  # Get the top-kth weight.
    top_k = jax.lax.stop_gradient(top_k)
    dots -= top_k[..., np.newaxis]  # Subtract (be 0 for lower ones).
    dots = np.maximum(dots, 0)
    dots_sum = np.sum(dots, axis=-1, keepdims=True)  # Re-normalize.
    dots /= dots_sum  # Re-normalize.

  out_slice = np.matmul(dots, value)
  return out_slice
def _update_sketched(self, step, g, x, m, v):
  """Update for higher-rank parameters."""
  shape = x.shape
  rank = len(shape)
  reshaped_accumulators = [
      np.reshape(v[i], self._expanded_shape(shape, i)) for i in range(rank)
  ]
  current_accumulator = self._minimum(reshaped_accumulators)
  current_accumulator += g * g
  accumulator_inv_sqrt = np.where(current_accumulator > 0.0,
                                  1.0 / np.sqrt(current_accumulator),
                                  np.zeros_like(current_accumulator))
  preconditioned_gradient = g * accumulator_inv_sqrt
  m = (1.0 - self._momentum) * preconditioned_gradient + self._momentum * m
  x = x - self.step_size(step) * m
  for i in range(len(v)):
    axes = list(range(int(i))) + list(range(int(i) + 1, rank))
    dim_accumulator = np.amax(current_accumulator, axis=axes)
    v[i] = dim_accumulator
  return x, (m, v)
def call(self, inputs, params=(), state=(), rng=None, **kwargs):
  del params, kwargs
  # We use the same vector as both a query and a key. For now we haven't
  # adjusted any of the surrounding code, so we still get a separate "key"
  # input that we ignore.
  qk, _, v = inputs
  seqlen = qk.shape[-2]

  # qk/v are n_hashes*n_batch*n_heads, seqlen, d_head
  # TODO(kitaev): is it faster to fuse this tiling into gather/scatter ops?
  qk = np.tile(qk, (self.n_hashes, 1, 1))
  v = np.tile(v, (self.n_hashes, 1, 1))

  # bins are n_hashes*n_batch*n_heads, seqlen
  # They specify which hash bucket the query/key/value vectors fall in.
  bins = self.hash_vectors(qk, rng=rng)

  # joint_t is n_hashes*n_batch*n_heads, seqlen
  joint_t = jax.lax.tie_in(qk, np.arange(seqlen))
  joint_t = np.reshape(joint_t, (1, seqlen))
  joint_t = np.broadcast_to(joint_t, qk.shape[:-1])

  assert int((self.n_buckets_per_bin * self.n_bins + 1) * seqlen) < 2**31, (
      'Potential 32-bit integer overflow; please double-check the code.')
  joint_bins_and_t = seqlen * bins + joint_t

  def chunk_scalars(x):  # pylint: disable=invalid-name
    return np.reshape(x, (x.shape[0], self.n_bins, -1))

  def chunk_vectors(x):  # pylint: disable=invalid-name
    return np.reshape(x, (x.shape[0], self.n_bins, -1, x.shape[-1]))

  def unchunk_vectors(x):  # pylint: disable=invalid-name
    return np.reshape(x, (x.shape[0], -1, x.shape[-1]))

  # Sort everything by bin number, with a secondary sort by time
  # (variables starting with "s" are sorted)
  _, sjoint_t = jax.lax.sort_key_val(joint_bins_and_t, joint_t, dimension=-1)
  _, undo_sort = jax.lax.sort_key_val(sjoint_t, joint_t, dimension=-1)
  # TODO(kitaev): why does jax flag integer indices as differentiable?
  # If we don't call stop_gradient here, custom gradients below won't work
  # because the primitive functions close over "differentiable" variables.
  sjoint_t = jax.lax.stop_gradient(sjoint_t)
  undo_sort = jax.lax.stop_gradient(undo_sort)

  # The backward pass of gather is in general a scatter operation, but we
  # know we're dealing with permutations so we use gather for the backward
  # pass too. This custom gradient should be about 2x faster than having jax
  # infer one that uses scatter ops instead.
  def permute_impl(vecs):
    assert len(vecs.shape) == 3
    return np.take_along_axis(vecs, sjoint_t[:, :, None], axis=-2)

  def unpermute_impl(vecs):
    assert len(vecs.shape) == 3
    return np.take_along_axis(vecs, undo_sort[:, :, None], axis=-2)

  @jax.custom_transforms
  def permute(vecs):
    return permute_impl(vecs)

  def permute_vjp(vecs):
    out_vecs = permute_impl(vecs)
    def vjpfun(grad):
      return (unpermute_impl(grad),)
    return out_vecs, vjpfun

  @jax.custom_transforms
  def unpermute(vecs):
    return unpermute_impl(vecs)

  def unpermute_vjp(vecs):
    out_vecs = unpermute_impl(vecs)
    def vjpfun(grad):
      return (permute_impl(grad),)
    return out_vecs, vjpfun

  jax.defvjp_all(permute, permute_vjp)
  jax.defvjp_all(unpermute, unpermute_vjp)

  sqk = permute(qk)
  sv = permute(v)

  # Split off a "bin" axis so that attention only occurs within chunks.
  bq_t = bkv_t = chunk_scalars(sjoint_t)
  bqk = chunk_vectors(sqk)
  bv = chunk_vectors(sv)

  # Hashing operates on unit-length vectors. Unnormalized query vectors are
  # fine because they effectively provide a learnable temperature for the
  # attention softmax, but normalizing keys is needed so that similarity for
  # the purposes of attention correctly corresponds to hash locality.
  bq = bqk
  bk = self.make_unit_length(bqk)

  # Allow each chunk to attend within itself, and also one chunk back. Chunk
  # boundaries might occur in the middle of a sequence of items from the
  # same bin, so this increases the chances of attending to relevant items.
  # TODO(kitaev): benchmark whether XLA pad operation is noticeably faster.
  bk_extra = np.concatenate([bk[:, -1:, :, :], bk[:, :-1, :, :]], axis=1)
  bk = np.concatenate([bk, bk_extra], axis=2)
  bv_extra = np.concatenate([bv[:, -1:, :, :], bv[:, :-1, :, :]], axis=1)
  bv = np.concatenate([bv, bv_extra], axis=2)
  bkv_t_extra = np.concatenate([bkv_t[:, -1:, :], bkv_t[:, :-1, :]], axis=1)
  bkv_t = np.concatenate([bkv_t, bkv_t_extra], axis=2)

  # Dot-product attention.
  dots = np.matmul(bq, np.swapaxes(bk, -1, -2)) / np.sqrt(bq.shape[-1])

  # Causal masking
  mask = jax.lax.convert_element_type(
      jax.lax.lt(bq_t[:, :, :, None], bkv_t[:, :, None, :]),
      np.float32)
  dots = dots - 1e9 * mask

  # Mask out attention to self except when no other targets are available.
  self_mask = jax.lax.broadcasted_eye(dots.dtype, dots.shape, (2, 3))
  self_mask = jax.lax.tie_in(dots, self_mask)
  dots = dots - 32 * self_mask

  # Softmax.
  dots_logsumexp = backend.logsumexp(dots, axis=-1, keepdims=True)
  dots = np.exp(dots - dots_logsumexp)

  if self._hard_k > 0:
    top_k = np.sort(dots)[..., -self._hard_k]  # Get the top-kth weight.
    top_k = jax.lax.stop_gradient(top_k)
    dots -= top_k[..., np.newaxis]  # Subtract (be 0 for lower ones).
    dots = np.maximum(dots, 0)
    dots_sum = np.sum(dots, axis=-1, keepdims=True)  # Sum to re-normalize.
    dots_logsumexp += np.log(dots_sum)  # Add it to the weight.
    dots /= dots_sum  # Re-normalize.

  bo = np.matmul(dots, bv)
  so = unchunk_vectors(bo)
  slogits = unchunk_vectors(dots_logsumexp)

  o = unpermute(so)
  logits = unpermute(slogits)

  o = np.reshape(o, (self.n_hashes, -1, seqlen, o.shape[-1]))
  logits = np.reshape(logits, (self.n_hashes, -1, seqlen, 1))
  probs = np.exp(logits - backend.logsumexp(logits, axis=0, keepdims=True))
  out = np.sum(o * probs, axis=0)
  assert out.shape == inputs[2].shape

  return out, state
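# Illustrative sketch (not from the source): the final lines of `call` above
# combine the n_hashes independent attention rounds by weighting each
# round's output with a softmax (taken over rounds) of its per-query
# log-sum-exp "logits".  Plain NumPy stands in for the backend `np`, and
# logsumexp is re-derived as log(sum(exp)); all sizes are made-up values.
import numpy as onp

n_hashes, batch, seqlen, d = 2, 1, 3, 4
o = onp.arange(n_hashes * batch * seqlen * d, dtype=onp.float32).reshape(
    (n_hashes, batch, seqlen, d))
logits = onp.zeros((n_hashes, batch, seqlen, 1), dtype=onp.float32)

probs = onp.exp(
    logits - onp.log(onp.sum(onp.exp(logits), axis=0, keepdims=True)))
out = onp.sum(o * probs, axis=0)
assert out.shape == (batch, seqlen, d)
# With equal logits every round gets weight 1/n_hashes, i.e. a plain average.
assert onp.allclose(out, onp.mean(o, axis=0))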
def make_unit_length(self, x, epsilon=1e-6):
  variance = np.mean(x**2, axis=-1, keepdims=True)
  norm_inputs = x / np.sqrt(variance + epsilon)
  return norm_inputs
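# Illustrative check (not from the source): `make_unit_length` divides each
# vector along the last axis by the RMS of its components, so the direction
# is preserved and the resulting L2 norm is sqrt(d) (up to epsilon), which is
# what matters for the hash-locality use above.  Plain NumPy stands in for
# the backend `np`, and the method body is repeated as a standalone helper.
import numpy as onp

def make_unit_length_example(x, epsilon=1e-6):
  variance = onp.mean(x**2, axis=-1, keepdims=True)
  return x / onp.sqrt(variance + epsilon)

x = onp.array([[3.0, 4.0]])
u = make_unit_length_example(x)
# ||[3, 4]|| = 5 and d = 2, so u = [3, 4] / (5 / sqrt(2)) with ||u|| = sqrt(2).
assert onp.allclose(onp.linalg.norm(u, axis=-1), onp.sqrt(2.0), atol=1e-3)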