def body(carry, qkx):
  # Reverse-scan body computing gradients for the causal (prefix-sum)
  # denominator: undoes `p += k` while accumulating the carry cotangent.
  p, p_ct = carry
  q, k, x_ct = qkx
  q_ct = jnp.einsum('...,...m->...m', x_ct, p, precision=precision)
  p_ct += jnp.einsum('...,...m->...m', x_ct, q, precision=precision)
  k_ct = p_ct
  p -= k
  return (p, p_ct), (q_ct, k_ct)
def body(carry, qkv_xct):
  # Reverse-scan body computing gradients for the causal numerator: undoes
  # `p += k v^T` while accumulating the carry cotangent `p_ct`.
  p, p_ct = carry
  q, k, v, x_ct = qkv_xct
  q_ct = jnp.einsum('...d,...md->...m', x_ct, p, precision=precision)
  p_ct += jnp.einsum('...d,...m->...md', x_ct, q, precision=precision)
  k_ct = jnp.einsum('...md,...d->...m', p_ct, v, precision=precision)
  v_ct = jnp.einsum('...md,...m->...d', p_ct, k, precision=precision)
  p -= jnp.einsum('...m,...d->...md', k, v, precision=precision)
  return (p, p_ct), (q_ct, k_ct, v_ct)
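# A minimal, self-contained driver sketch (assumed shapes and names, not the
# library's exact wiring): backward bodies like the two above are typically
# run as a reverse jax.lax.scan inside a custom-VJP rule, seeded with the
# final forward prefix sum and a zero cotangent carry. The `precision`
# argument is omitted here for brevity.
import jax
import jax.numpy as jnp

def numerator_grads_sketch(qs, ks, vs, x_cts):
  # qs, ks: [L, B, M]; vs, x_cts: [L, B, D].
  p_final = jnp.einsum('lbm,lbd->bmd', ks, vs)  # Final forward prefix sum.

  def body(carry, qkv_xct):
    p, p_ct = carry
    q, k, v, x_ct = qkv_xct
    q_ct = jnp.einsum('...d,...md->...m', x_ct, p)
    p_ct += jnp.einsum('...d,...m->...md', x_ct, q)
    k_ct = jnp.einsum('...md,...d->...m', p_ct, v)
    v_ct = jnp.einsum('...md,...m->...d', p_ct, k)
    p -= jnp.einsum('...m,...d->...md', k, v)  # Undo the forward update.
    return (p, p_ct), (q_ct, k_ct, v_ct)

  _, grads = jax.lax.scan(body, (p_final, jnp.zeros_like(p_final)),
                          (qs, ks, vs, x_cts), reverse=True)
  return grads  # (q_cts, k_cts, v_cts), each stacked in original time order.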
def DotProductAttention(queries, keys, values, pos_emb, context_bias,
                        location_bias, mask, separate_cls, dropout, mode,
                        rng):
  """Computes new activations via masked attention-weighted sum of values.

  This function is the core of the attention mechanism. It:
    - computes per-head attention weights from per-head `queries` and `keys`,
    - applies `mask` to screen out positions that come from padding tokens,
    - optionally applies dropout to attention weights, and
    - uses attention weights to combine per-head `values` vectors.

  Args:
    queries: Per-head activations representing attention queries.
    keys: Per-head activations representing attention keys.
    values: Per-head activations to be combined by computed attention
      weights.
    pos_emb: Per-head activations representing positional embeddings.
    context_bias: Global context bias from Transformer XL's attention.
    location_bias: Global location bias from Transformer XL's attention.
    mask: Mask that distinguishes positions with real content vs. padding.
    separate_cls: If `True`, handle the `cls` token separately in the
      attention calculation.
    dropout: Probabilistic rate for dropout applied to attention strengths
      (based on query-key pairs) before applying them to values.
    mode: One of `'train'`, `'eval'`, or `'predict'`.
    rng: Single-use random number generator (JAX PRNG key).

  Returns:
    Tuple of (activations, attention weights): per-head activations
    resulting from the masked, attention-weighted sum of per-head values,
    together with the attention weights themselves.
  """
  d_feature = queries.shape[-1]
  keys_len, queries_len = keys.shape[-2], queries.shape[-2]
  funnel_factor, is_upsampling = calc_funnel_ratio(keys_len, queries_len)

  ac = jnp.einsum('bnid,bnjd->bnij', queries + context_bias, keys)
  bd = jnp.einsum('bnid,jnd->bnij', queries + location_bias, pos_emb)

  if mode != 'predict':
    bd = _fast_matrix_shift(bd, funnel_factor, is_upsampling)

  if separate_cls:
    # Masking out location part of attention for cls token.
    bd = bd.at[:, :, :, 0].set(0)
    bd = bd.at[:, :, 0, :].set(0)

  dots = (ac + bd) / jnp.sqrt(d_feature)
  if mask is not None:
    dots = jnp.where(mask, dots, jnp.full_like(dots, -1e9))

  # Softmax.
  dots = jnp.exp(dots - fastmath.logsumexp(dots, axis=-1, keepdims=True))
  if dropout is not None and dropout >= 1.0:
    raise ValueError('Dropout rates must be lower than 1.')
  if dropout is not None and dropout > 0.0 and mode == 'train':
    keep = fastmath.random.bernoulli(rng, 1.0 - dropout, dots.shape)
    dots = jnp.where(keep, dots / (1.0 - dropout), jnp.zeros_like(dots))
  out = jnp.matmul(dots, values)
  out = out.astype(jnp.float32)
  dots = dots.astype(jnp.float32)
  return out, dots
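# For reference (a hedged restatement in the Transformer-XL paper's notation,
# not names from this file): the relative-attention score computed above
# decomposes as
#   A_ij = (q_i + u) . k_j + (q_i + v) . r_{i-j},
# where u is `context_bias`, v is `location_bias`, and r is `pos_emb`;
# `ac` is the content term, `bd` is the (shifted) position term, and the sum
# is scaled by 1/sqrt(d_feature) before the softmax.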
def EinsumDense(d_input, d_output, use_bias):
  """Returns a reimplementation of the Dense layer, using einsum.

  While this is equivalent to a Dense layer, it seems to be faster in
  decoding when used with bias (see decoding_timing_test.py). This layer can
  be removed once we better understand the reason for the difference in
  decoding speed.

  Args:
    d_input: Dimensionality of the input tensor.
    d_output: Dimensionality of the output tensor.
    use_bias: Whether to use bias.
  """
  layers = [
      tl.Weights(init.GlorotUniformInitializer(), [d_output, d_input]),
      tl.Fn(
          'EinsumDense',
          (lambda kernel, embeds:  # pylint: disable=g-long-lambda
           jnp.einsum('xd,...d->...x', kernel, embeds)))
  ]
  if use_bias:
    layers.extend([
        tl.Weights(init.RandomNormalInitializer(1e-6), [d_output]),
        tl.Add()
    ])
  return tl.Serial(layers)
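# A usage sketch, assuming the trax imports used above (`tl`, `init`) plus
# `trax.shapes`; the shapes here are arbitrary.
import numpy as np
from trax import shapes

layer = EinsumDense(d_input=8, d_output=16, use_bias=True)
x = np.zeros((2, 8), dtype=np.float32)
layer.init(shapes.signature(x))
y = layer(x)  # y.shape == (2, 16)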
def Sinusoidal_Embeddings(positions):
  # Note: `d_feature` is captured from the enclosing scope.
  inv_freq = 1 / (10000**(jnp.arange(0.0, d_feature, 2.0) / d_feature))
  sinusoid_freq = jnp.einsum('i,j->ij', positions, inv_freq)
  pos_emb = jnp.concatenate(
      [jnp.sin(sinusoid_freq), jnp.cos(sinusoid_freq)], axis=1)
  return pos_emb
def _calc_attn_scores(q, k):
  # Note: `context_bias`, `location_bias`, `pos_emb`, `mask`, `d_feature`,
  # `dropout`, `mode` and `rng` are captured from the enclosing scope.
  ac = jnp.einsum('bnid,bnjd->bnij', q + context_bias, k)
  bd = jnp.einsum('bnid,jnd->bnij', q + location_bias, pos_emb)

  if mode != 'predict':
    bd = _fast_matrix_shift(bd)

  dots = (ac + bd) / jnp.sqrt(d_feature)
  dots = jnp.where(mask, dots, jnp.full_like(dots, -1e9))

  # Softmax.
  dots = jnp.exp(dots - fastmath.logsumexp(dots, axis=-1, keepdims=True))
  if dropout is not None and dropout >= 1.0:
    raise ValueError('Dropout rates must be lower than 1.')
  if dropout is not None and dropout > 0.0 and mode == 'train':
    keep = fastmath.random.bernoulli(rng, 1.0 - dropout, dots.shape)
    dots = jnp.where(keep, dots / (1.0 - dropout), jnp.zeros_like(dots))
  return dots
def MultiplicativeSparseDense(sparsity, d_input, d_output=None, use_bias=True,
                              use_bfloat16=False):
  """Returns a replacement of Dense layer which uses fewer parameters.

  The layer uses a number of modules equal to `sparsity`. It multiplies each
  dimension of the input tensor by a scalar specific to each dimension and
  each module separately; then it applies Dense(d_output/sparsity) to each
  module. Compared to a standard dense layer, MultiplicativeSparseDense uses
  fewer parameters while still being able to express many interesting
  functions (for example, a permutation).

  Args:
    sparsity: The sparsity of the layer; the output vector is divided into
      this number of modules.
    d_input: Dimensionality of input tensor.
    d_output: Dimensionality of output tensor; by default equal to d_input.
    use_bias: Whether to use bias.
    use_bfloat16: Whether to use bfloat16 for weights.
  """
  d_output = d_input if d_output is None else d_output
  assert d_output % sparsity == 0
  d_module = d_output // sparsity

  layers = [
      # Weight below is used for per-head preprocessing of an embedding.
      tl.Weights(init.RandomNormalInitializer(stddev=0.5),
                 shape=[sparsity, d_input], use_bfloat16=use_bfloat16),
      # Weight below is the dense kernel, shared across heads.
      tl.Weights(init.GlorotUniformInitializer(), [d_input, d_module],
                 use_bfloat16=use_bfloat16),
      # To save memory, the per-head preprocessing and multiplying by the
      # kernel is done in the same einsum.
      tl.Fn(
          'AttentionEinsum',
          (lambda kernel, multiplier, embeds:  # pylint: disable=g-long-lambda
           jnp.einsum('dx,hd,...d->...hx', kernel, multiplier, embeds))),
      MergeLastTwoAxes(),
  ]
  if use_bias:
    layers.extend([
        # Weight below is bias after dense, per-head.
        tl.Weights(init.RandomNormalInitializer(1e-6), [d_output],
                   use_bfloat16=use_bfloat16),
        tl.Add(),
    ])
  return tl.Serial(layers)
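# A usage sketch mirroring the EinsumDense one above (same assumed imports).
# With sparsity=2, d_input=8 and d_output=16 the layer holds a [2, 8]
# multiplier, an [8, 8] shared kernel and a [16] bias; fewer parameters than
# the corresponding Dense kernel of shape [8, 16].
layer = MultiplicativeSparseDense(sparsity=2, d_input=8, d_output=16)
x = np.zeros((2, 8), dtype=np.float32)
layer.init(shapes.signature(x))
y = layer(x)  # y.shape == (2, 16)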
def rotate(x):
  """Applies rotary position embeddings to `x`."""
  _, l, d = x.shape
  inv_freq = jnp.exp(jnp.arange(0, d, 2) * -(jnp.log(10000.0) / d))
  positions = jnp.arange(l)
  freqs = jnp.einsum('i,j->ij', positions, inv_freq)
  emb = jnp.concatenate((freqs, freqs), axis=-1)
  cos = jnp.cos(emb)
  sin = jnp.sin(emb)

  def mul(vecs, pos_emb):
    return jnp.einsum('bld,ld->bld', vecs, pos_emb)

  def rotate_half(x):
    x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
    return jnp.concatenate((-x2, x1), axis=x1.ndim - 1)

  return mul(x, cos) + mul(rotate_half(x), sin)
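# A minimal usage sketch (hypothetical shapes): rotary embeddings are applied
# independently to queries and keys before the attention dot product, so that
# q_rot . k_rot depends on positions only through their difference.
q = jnp.ones((2, 16, 64))  # [batch, length, d]; d must be even.
k = jnp.ones((2, 16, 64))
q_rot, k_rot = rotate(q), rotate(k)
assert q_rot.shape == q.shape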
def Sinusoidal_Embeddings(positions, d_feature):
  """Sinusoidal embeddings.

  Computes, from a 1-D vector of integer absolute positions, the sinusoidal
  embeddings defined as in the paper "Attention Is All You Need" (2017).
  Embeddings have shape (len(positions), d_feature).

  Args:
    positions: A one-dimensional array of positions.
    d_feature: The number of sin-cos features.

  Returns:
    Positional embeddings.
  """
  inv_freq = 1 / (10000**(jnp.arange(0.0, d_feature, 2.0) / d_feature))
  sinusoid_freq = jnp.einsum('i,j->ij', positions, inv_freq)
  pos_emb = jnp.concatenate(
      [jnp.sin(sinusoid_freq), jnp.cos(sinusoid_freq)], axis=1)
  return pos_emb
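# Example: embeddings for 8 positions with d_feature=16; pair j of
# (sin, cos) columns oscillates at frequency 1/10000^(2j/d_feature).
positions = jnp.arange(8.0)
emb = Sinusoidal_Embeddings(positions, 16)  # emb.shape == (8, 16)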
def MultiplicativeModularSparseDense(sparsity, d_feature):
  """Returns a replacement of Dense layer which uses fewer parameters.

  The layer uses a number of modules equal to `sparsity`. It is a combination
  of multiplicative dense and locally connected dense layers.

  Args:
    sparsity: The sparsity of the layer; the output vector is divided into
      this number of modules.
    d_feature: Dimensionality of input and output tensor.
  """
  assert d_feature % sparsity == 0
  d_module = d_feature // sparsity

  return tl.Serial(
      # Weight below is used for per-head preprocessing of an embedding.
      tl.Weights(init.RandomNormalInitializer(stddev=0.5),
                 shape=[sparsity, d_feature]),
      # Weight below is a kernel of multiplicative dense, shared across
      # heads.
      tl.Weights(init.GlorotUniformInitializer(), [d_feature, d_module]),
      # Weight below is a kernel of modular dense.
      tl.Weights(
          functools.partial(init.GlorotUniformInitializer(),
                            nonreceptive_dims=[0]),
          [sparsity, d_module, d_module]),
      # To save memory, the per-head preprocessing and multiplying by the
      # kernels is done in a single einsum.
      tl.Fn(
          'SparseDenseEinsum',
          (lambda kmod, kmult, multiplier, embeds:  # pylint: disable=g-long-lambda
           jnp.einsum('hxo,dx,hd,...d->...ho', kmod, kmult, multiplier,
                      embeds))),
      MergeLastTwoAxes(),
      # Weight below is bias after dense, per-head.
      tl.Weights(init.RandomNormalInitializer(1e-6), [d_feature]),
      tl.Add(),
  )
def forward(self, x):
  """Executes this layer as part of a forward pass through the model.

  Args:
    x: Tensor of same shape and dtype as the input signature used to
      initialize this layer.

  Returns:
    Tensor of same shape and dtype as the input, except the final dimension
    is the layer's `filters` value, and the second-to-last dimension is
    shrunk if 'VALID' padding is used with a kernel_size bigger than one.
  """
  if self._use_bias:
    if not isinstance(self.weights, (tuple, list)):
      raise ValueError(f'Weights should be a (w, b) tuple or list; '
                       f'instead got: {self.weights}')
    w, b = self.weights
  else:
    w = self.weights

  linear_results_before_shifting = jnp.einsum('...lp,lkpd->...lkd', x, w)
  # TODO(jaszczur): this could be run after padding for better efficiency

  if self._kernel_size == 1:
    # With kernel size 1 we don't have to split or shift anything.
    linear_result = jnp.squeeze(linear_results_before_shifting, axis=-2)
  else:
    # We computed a result for every "pixel", but each direction from the
    # receptive field (there are 'self._kernel_size' such directions) must
    # be shifted by a different amount. The easiest way to do it is to split
    # the tensor into 'self._kernel_size' smaller tensors, shift each one
    # appropriately, and then sum them together.
    split_shifting_linear_results = jnp.split(
        linear_results_before_shifting, self._kernel_size, axis=-2)

    for i in range(self._kernel_size):
      # Each tensor has to be shifted a different amount.
      if self._padding == 'WRAP':
        # We can shift by padding and cutting. With 'wrap' padding we
        # essentially have a torus.
        padding = [(0, 0) for _ in split_shifting_linear_results[i].shape]
        padding[-3] = ((self._kernel_size - 1) - i, i)
        split_shifting_linear_results[i] = jnp.pad(
            split_shifting_linear_results[i], padding, mode='wrap')
        split_shifting_linear_results[i] = split_shifting_linear_results[i][
            ..., (self._kernel_size - 1) // 2:-(self._kernel_size - 1) // 2,
            :, :]
      elif self._padding == 'SAME':
        # We can shift by padding and cutting.
        padding = [(0, 0) for _ in split_shifting_linear_results[i].shape]
        padding[-3] = ((self._kernel_size - 1) - i, i)
        split_shifting_linear_results[i] = jnp.pad(
            split_shifting_linear_results[i], padding)
        split_shifting_linear_results[i] = split_shifting_linear_results[i][
            ..., (self._kernel_size - 1) // 2:-(self._kernel_size - 1) // 2,
            :, :]
        # TODO(jaszczur): improve efficiency by not padding things to cut
      elif self._padding == 'VALID':
        # We don't need to shift; just cut the leftmost and rightmost values.
        cut_left = (self._kernel_size - 1) - i
        cut_right = split_shifting_linear_results[i].shape[-3] - i
        split_shifting_linear_results[i] = split_shifting_linear_results[i][
            ..., cut_left:cut_right, :, :]
      else:
        raise ValueError(f'Invalid padding {self._padding}')

    # After shifting.
    shifted_linear_results = jnp.concatenate(split_shifting_linear_results,
                                             axis=-2)
    linear_result = jnp.sum(shifted_linear_results, axis=-2)

  if self._use_bias:
    return linear_result + b
  else:
    return linear_result
def forward(self, x):
  """Executes this layer as part of a forward pass through the model.

  Args:
    x: Tensor of same shape and dtype as the input signature used to
      initialize this layer.

  Returns:
    Tensor of same shape and dtype as the input.
  """
  m1, m2, mb, w1, w2, b2 = self.weights
  if self._mode != 'predict':
    w1 = jnp.reshape(w1.T, (-1, self._d_ff))
    w2 = jnp.reshape(w2, (self._d_ff, -1))
  x_shape = x.shape
  x = jnp.reshape(x, [-1, x_shape[-1]])  # Easier to operate on flattened x.

  # Q: should we add bias and/or put relu after the low-rank m1 dot?
  mask_logits = jnp.dot(jnp.dot(x, m1), m2) + mb
  mask_logits = jnp.reshape(mask_logits, [-1, self._d1, self._d2])

  # Softmax.
  mask_logsumexp = fastmath.logsumexp(mask_logits, axis=-1, keepdims=True)
  log_mask = mask_logits - mask_logsumexp
  mask = jnp.exp(log_mask)

  # Gumbel-softmax with straight-through discretization.
  rng1, rng2 = fastmath.random.split(self.rng, 2)
  u = fastmath.random.uniform(rng1, mask.shape, jnp.float32, 1e-6,
                              1.0 - 1e-6)
  g = -jnp.log(-jnp.log(u))
  quant_mask = jnp.argmax(log_mask + g * self._temperature, axis=-1)

  if self._mode == 'train':
    # Tricks from Section 2.1 in https://arxiv.org/abs/1801.09797
    quant_mask = tl.one_hot(quant_mask, self._n_elements_in_block)
    quant_mask = fastmath.stop_gradient(quant_mask)
    quant_mask += mask - fastmath.stop_gradient(mask)  # straight-through
    # We will sometimes (quant_prob of the batches) use the soft-mask
    # instead of the quantized mask to improve training stability (see the
    # paper above).
    select = fastmath.random.uniform(rng2, (), jnp.float32, 0.0, 1.0)
    quant_mask = jnp.where(select < self._quant_prob, quant_mask, mask)
    quant_mask = jnp.reshape(quant_mask, [-1, self._d_ff])

  if self._mode == 'train':
    # In training, run the full matmul to get benefits from the above
    # tricks.
    mid = jnp.dot(x, w1) * quant_mask  # [joint_batch, d_ff]
    relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid)
    res = jnp.dot(relu, w2) + b2
  elif self._mode == 'predict':
    # w1 = jnp.reshape(w1.T, (self._d1, self._d2, -1))
    # w2 = jnp.reshape(w2, (self._d1, self._d2, -1))
    # This implementation mimics inference. It's not efficient for a large
    # joint_batch, but at inference that will be 1 most of the time.
    # Shapes:
    #   quant_mask is [joint_batch, self._d1]
    #   w1 is [d_model, self._d1, self._d2]
    # We'll index w1 with advanced numpy indexing: the first range is over
    # self._d1 times the batch size, the second range is quant_mask.
    batch_size = quant_mask.shape[0]
    idx1 = jnp.array([jnp.arange(self._d1)] * batch_size)
    # Flatten indices and select from w1.
    idx1 = jnp.reshape(idx1, [-1])
    idx2 = jnp.reshape(quant_mask, [-1])
    w = w1[idx1, idx2, :]  # Now we have per-element weights with batch dim.
    w = jnp.reshape(w, [batch_size, self._d1, -1])
    mid = jnp.einsum('ai,aji->aj', x, w)
    relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid)
    # w2 is [self._d1, self._d2, d_model]
    v = w2[idx1, idx2, :]
    v = jnp.reshape(v, [batch_size, self._d1, -1])
    res = jnp.einsum('ai,aij->aj', relu, v) + b2
  else:
    quant_mask = tl.one_hot(quant_mask, self._n_elements_in_block)
    quant_mask = jnp.reshape(quant_mask, [-1, self._d_ff])
    mid = jnp.dot(x, w1) * quant_mask  # [joint_batch, d_ff]
    relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid)
    res = jnp.dot(relu, w2) + b2

  return jnp.reshape(res, x_shape)  # Un-flatten if needed.
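# For reference, a hedged note on the discretization above: with
# u ~ Uniform(0, 1) and g = -log(-log(u)) (standard Gumbel noise),
# argmax(log_mask + g * temperature) draws a sample from a tempered version
# of the softmax distribution (the Gumbel-max trick);
# `quant_mask += mask - stop_gradient(mask)` then makes the forward pass use
# the one-hot sample while gradients flow through the soft `mask`, i.e. the
# straight-through estimator.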
def mul(vecs, pos_emb):
  return jnp.einsum('bld,ld->bld', vecs, pos_emb)
def body(p, qk):
  q, k = qk
  p += k
  x = jnp.einsum('...m,...m->...', q, p, precision=precision)
  return p, x
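# A self-contained driver sketch for the denominator body above (assumed
# shapes, `jax`/`jnp` imports as in the earlier sketch, `precision` omitted):
# the carry is the running sum of key features.
def causal_denominator_sketch(qs, ks):
  # qs, ks: kernel-feature tensors of shape [L, B, M], scanned over axis 0.
  def body(p, qk):
    q, k = qk
    p += k
    x = jnp.einsum('...m,...m->...', q, p)
    return p, x
  _, out = jax.lax.scan(body, jnp.zeros_like(ks[0]), (qs, ks))
  return out  # [L, B]: denominator for each position.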
def test_lsh_and_pure_lsh_self_attention_equivalence(self):
  # Given the same weight matrices and random numbers, do these produce the
  # same output?
  with fastmath.use_backend(fastmath.Backend.JAX):
    n_heads = 4
    d_head = 4
    d_model = n_heads * d_head
    pure_lsh_layer = efficient_attention.PureLSHSelfAttention(
        n_heads=n_heads,
        d_qk=d_head,
        d_v=d_head,
        causal=True,
        masked=False,
        chunk_len=8,
        n_chunks_before=1,
        n_chunks_after=0,
        n_hashes=4,
        n_buckets=8,
        use_reference_code=False,
        attention_dropout=0.0,
        use_python_loop=True,
        bias=False,
        mode='train')
    lsh_layer = efficient_attention.LSHSelfAttention(
        n_heads=n_heads,
        d_qk=d_head,
        d_v=d_head,
        causal=True,
        masked=False,
        chunk_len=8,
        n_chunks_before=1,
        n_chunks_after=0,
        n_hashes=4,
        n_buckets=8,
        use_reference_code=False,
        attention_dropout=0.0,
        use_python_loop=True,
        mode='train')

    batch, seqlen = 3, 32
    input_shape = (batch, seqlen, d_model)

    x = jax.random.uniform(jax.random.PRNGKey(0), input_shape,
                           dtype=jnp.float32)
    lsh_layer_input = x

    call_rng = jax.random.PRNGKey(42)

    lsh_layer_weights, lsh_layer_state = lsh_layer.init(
        shapes.signature(lsh_layer_input))
    lsh_layer.rng = call_rng
    lsh_layer_output = lsh_layer(lsh_layer_input)

    # Shapes are: (n_heads, d_model, d_head), (n_heads, d_model, d_head),
    # (n_heads, d_head, d_model). Abbreviated as hmn, hmn, hnm.
    w_qk, w_v, w_o = lsh_layer_weights

    qk = jnp.einsum('blm,hmn->bhln', x, w_qk)
    qk = qk.reshape((-1, qk.shape[2], qk.shape[3]))

    v = jnp.einsum('blm,hmn->bhln', x, w_v)
    v = v.reshape((-1, v.shape[2], v.shape[3]))

    pure_lsh_layer_input = (qk, v)
    _, _ = pure_lsh_layer.init(shapes.signature(pure_lsh_layer_input))
    pure_lsh_layer.rng = call_rng
    pure_lsh_layer.state = lsh_layer_state
    pure_lsh_layer_output = pure_lsh_layer(pure_lsh_layer_input)

    # pure_lsh_layer_output is [b*h, l, n]; regroup the head dimension and
    # apply the output projection.
    pure_lsh_layer_output = pure_lsh_layer_output.reshape(
        (batch, -1) + pure_lsh_layer_output.shape[1:])
    pure_lsh_layer_output_projected = jnp.einsum(
        'bhld,hdm->blm', pure_lsh_layer_output, w_o)

    diff = pure_lsh_layer_output_projected - lsh_layer_output
    avg_diff = jnp.sum(jnp.abs(diff)) / jnp.sum(jnp.ones_like(diff))

    self.assertLess(avg_diff, 1e-5)
def body(p, qkv):
  (q, k, v) = qkv
  p += jnp.einsum('...m,...d->...md', k, v, precision=precision)
  x_slice = jnp.einsum('...m,...md->...d', q, p, precision=precision)
  return p, x_slice
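# The matching driver sketch for the numerator body above (same assumptions
# as the denominator sketch): the carry accumulates the prefix sum of k v^T.
def causal_numerator_sketch(qs, ks, vs):
  # qs, ks: [L, B, M]; vs: [L, B, D].
  def body(p, qkv):
    q, k, v = qkv
    p += jnp.einsum('...m,...d->...md', k, v)
    x_slice = jnp.einsum('...m,...md->...d', q, p)
    return p, x_slice
  init = jnp.zeros(ks.shape[1:] + vs.shape[-1:])  # [B, M, D]
  _, out = jax.lax.scan(body, init, (qs, ks, vs))
  return out  # [L, B, D]: numerator for each position.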
def bidirectional_denominator(query_prime, key_prime):
  all_ones = jnp.ones([query_prime.shape[0]])
  ks_sum = jnp.einsum('lbm,l->bm', key_prime, all_ones)
  return jnp.einsum('lbm,bm->lb', query_prime, ks_sum)
def bidirectional_numerator(query_prime, key_prime, value):
  kvs = jnp.einsum('lbm,lbd->bmd', key_prime, value)
  return jnp.einsum('lbm,bmd->lbd', query_prime, kvs)
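# A hedged sketch of how the two bidirectional helpers compose into
# (non-causal) kernel attention; `query_prime`/`key_prime` stand for
# kernel-transformed queries/keys of shape [L, B, M], `value` is [L, B, D].
def bidirectional_attention_sketch(query_prime, key_prime, value):
  num = bidirectional_numerator(query_prime, key_prime, value)  # [L, B, D]
  den = bidirectional_denominator(query_prime, key_prime)       # [L, B]
  return num / den[..., None]                                   # [L, B, D]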
def reverse(self, x, weights=(), state=(), new_state=(), rng=None):
  del state, new_state, rng
  shape = x.shape
  x = x.reshape(shape[:-1] + (self._get_multiplier(x), -1))
  t_x = jnp.einsum('...ab->...ba', x)  # Transpose.
  return t_x.reshape(shape)
def forward(self, x):
  shape = x.shape
  x = x.reshape(shape[:-1] + (-1, self._get_multiplier(x)))
  t_x = jnp.einsum('...ab->...ba', x)  # Transpose.
  return t_x.reshape(shape)
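# A standalone check (assuming a multiplier of 4) of the reshape-transpose
# pair above: `forward` interleaves the last axis and `reverse` undoes it.
x = jnp.arange(8.0)[None, :]                                    # [1, 8]
f = jnp.einsum('...ab->...ba', x.reshape((1, -1, 4))).reshape(x.shape)
# f is [0., 4., 1., 5., 2., 6., 3., 7.]
r = jnp.einsum('...ab->...ba', f.reshape((1, 4, -1))).reshape(x.shape)
assert jnp.array_equal(x, r)  # The round trip recovers the input.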