def _funnel_mask(self, batch_size, keys_len, queries_len, funnel_factor,
                 is_upsampling):
  """Creates a funnel mask.

  Based on the keys/queries lengths, this function creates a lower-triangular
  mask that prevents tokens from attending to positions that follow them. If
  funnel_factor is not equal to 1 (due to funnel upsampling or downsampling),
  the mask is adjusted for funnel attention by repeating each element
  funnel_factor times. This is because after a funnel layer one token attends
  to funnel_factor different tokens during downsampling; during upsampling, on
  the other hand, funnel_factor tokens attend to a single token from before
  upsampling.

  Args:
    batch_size: batch size.
    keys_len: keys length.
    queries_len: queries length.
    funnel_factor: funnel factor.
    is_upsampling: True if the mask is for an upsampling layer.

  Returns:
    Funnel mask.
  """
  if self._mode == 'predict':
    # We cannot generate more than one token at a time because it would
    # contradict the autoregressive property.
    assert queries_len == 1
    mask = jnp.arange(
        self._max_len) <= (self.state // self._total_kv_pooling)
    mask = jnp.reshape(mask, (1, 1, 1, self._max_len))
    mask = jnp.repeat(mask, batch_size, axis=0)
    self.state += self._n_raw_tokens_generated
    return mask

  if funnel_factor != 1:
    if not is_upsampling:
      mask = jnp.tril(jnp.ones((queries_len, queries_len), dtype=jnp.bool_))
      mask = jnp.repeat(mask, funnel_factor, axis=-1)
    else:
      mask = jnp.tril(jnp.ones((keys_len, keys_len), dtype=jnp.bool_))
      mask = jnp.repeat(mask, funnel_factor, axis=-2)
  else:
    mask = jnp.tril(jnp.ones((queries_len, queries_len), dtype=jnp.bool_))

  return jnp.repeat(mask[None, None, :, :], batch_size, axis=0)
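# A minimal illustration (not part of the layer above) of how repeating a
# lower-triangular mask models funnel attention, assuming funnel_factor = 2
# and 3 pooled positions:
#
#   tri = jnp.tril(jnp.ones((3, 3), dtype=jnp.bool_))
#   down = jnp.repeat(tri, 2, axis=-1)   # (3, 6): each pooled query may
#                                        # attend to 2 raw key positions.
#   up = jnp.repeat(tri, 2, axis=-2)     # (6, 3): 2 upsampled queries share
#                                        # each pre-upsampling key position.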
def forward(self, inputs):
  q, k, v = inputs

  if self._mode == 'predict':
    self.state = _fast_inference_update_state(inputs, self.state)
    (k, v, mask, _) = self.state
  else:
    mask_size = q.shape[-2]
    # Not all backends define jnp.tril. However, using np.tril is inefficient
    # in that it creates a large global constant. TODO(kitaev): try to find an
    # alternative that works across all backends.
    if fastmath.backend_name() == 'jax':
      mask = jnp.tril(
          jnp.ones((1, mask_size, mask_size), dtype=np.bool_), k=0)
    else:
      mask = np.tril(np.ones((1, mask_size, mask_size), dtype=np.bool_), k=0)

  res, dots = DotProductAttention(
      q, k, v, mask, dropout=self._dropout, mode=self._mode, rng=self.rng)
  if self._mode == 'viz':
    self.state = dots
  return res
def fbo(inputs, weights, state, slots, opt_params, rng, step, grads):
  """FBO of the layer."""
  # We need a layer pure_fn but only for inputs and weights.
  def pure_fn_without_state_and_rng(x, w):
    return layer.pure_fn(x, w, state, rng)

  # Calculate the vector-Jacobian product of the reduced pure fn.
  activations, vjp_fn, new_state = fastmath.vjp(
      pure_fn_without_state_and_rng, inputs, weights, has_aux=True)

  # In the loss layer, set gradients to 1 with the dtype of activations=loss.
  if grads is None and stats_name is not None:
    grads = jnp.ones((), dtype=activations.dtype)

  # The vjp function returns gradients with respect to inputs and weights.
  grads_inputs, grads_weights = vjp_fn(grads)

  # For non-trainable layers, return the calculated arguments.
  if _is_empty_tuple(weights):
    stats = {}
    if stats_name is not None:
      stats[stats_name] = activations
    return weights, new_state, slots, grads_inputs, stats

  # In multi-device setting, average gradients from multiple devices.
  if n_devices > 1:
    grads_weights = _average_multidevice_gradients(grads_weights)

  # Run the optimizer.
  new_weights, new_slots, stats = optimizer.tree_update(
      step, grads_weights, weights, slots, opt_params)
  if stats_name is not None:
    stats[stats_name] = activations
  return new_weights, new_state, new_slots, grads_inputs, stats
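# A minimal, self-contained sketch (toy_pure_fn is hypothetical, not from the
# code above) of the vector-Jacobian product pattern used in fbo, written with
# jax.vjp directly: the returned vjp_fn maps a cotangent of the output to
# gradients with respect to both inputs and weights.
import jax
import jax.numpy as jnp


def toy_pure_fn(x, w):
  out = jnp.sum(x * w)     # stands in for layer.pure_fn restricted to (x, w)
  return out, {'aux': 0}   # (output, auxiliary value), as with has_aux=True


x, w = jnp.ones(3), jnp.arange(3.0)
out, vjp_fn, aux = jax.vjp(toy_pure_fn, x, w, has_aux=True)
grads_x, grads_w = vjp_fn(jnp.ones((), dtype=out.dtype))
print(grads_x, grads_w)  # d(out)/dx = w = [0. 1. 2.], d(out)/dw = x = [1. 1. 1.]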
def _fast_inference_update_state(inputs, state):
  """Updates state of a causal attention layer for fast inference.

  The layer state stores tensors with cached values of keys and values, as
  well as the mask and an index. To make shapes static, keys and values in the
  state are long, and the index indicates where the new keys and values from
  inputs need to be appended. The mask ensures that attention only looks at
  keys up to the index.

  During the update, we append new_keys and new_values to keys and values at
  the position given by the index. We also update the mask (which starts as
  all zeros) to be 1 at the new key positions, and we increment the index by
  the length of the new keys.

  Args:
    inputs: a triple (new_queries, new_keys, new_values)
    state: layer state with (keys, values, mask, index)

  Returns:
    Updated state.
  """
  # Fast inference: run step-by-step, storing the sequence
  # of keys and values calculated so far in state.
  (_, new_k, new_v) = inputs
  length = new_k.shape[1]
  (ks, vs, mask, idx) = state
  # TODO(lukaszkaiser): benchmark speed and decide if using a separate code
  # path with index_update when length == 1 is worth it.
  # Keys and values are of shape [batch_size, length, d_kv].
  ks = fastmath.dynamic_update_slice_in_dim(ks, new_k, idx, axis=1)
  vs = fastmath.dynamic_update_slice_in_dim(vs, new_v, idx, axis=1)
  # Mask is of shape [batch_size, 1 (for heads), length].
  new_mask = jnp.ones((mask.shape[0], mask.shape[1], length))
  mask = fastmath.dynamic_update_slice_in_dim(mask, new_mask, idx, axis=2)
  return (ks, vs, mask, idx + length)
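# A minimal sketch of the cache update above, using jax.lax in place of the
# fastmath wrapper (an assumption about the wrapper's behavior): the key cache
# is preallocated to the maximum length and new keys are written at `idx`.
import jax
import jax.numpy as jnp

max_len, d_kv = 8, 4
ks = jnp.zeros((1, max_len, d_kv))    # preallocated key cache
new_k = jnp.ones((1, 2, d_kv))        # two freshly computed keys
idx = 3
ks = jax.lax.dynamic_update_slice_in_dim(ks, new_k, idx, axis=1)
print(jnp.sum(ks[0], axis=-1))        # nonzero only at positions 3 and 4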
def forward(self, inputs): """Returns attention-computed activations. Args: inputs: A (queries, keys, values) tuple. """ q, k, v = inputs if self._mode == 'predict': self.state = _fast_inference_update_state(inputs, self.state) (k, v, mask, _) = self.state else: mask_size = q.shape[-2] # Not all backends define jnp.tril. However, using np.tril is inefficient # in that it creates a large global constant. TODO(kitaev): try to find an # alternative that works across all backends. if fastmath.is_backend(fastmath.Backend.JAX): mask = jnp.tril(jnp.ones((1, mask_size, mask_size), dtype=np.bool_), k=0) else: mask = np.tril(np.ones((1, mask_size, mask_size), dtype=np.bool_), k=0) res, dots = DotProductAttention(q, k, v, mask, dropout=self._dropout, mode=self._mode, rng=self.rng) if self._mode == 'viz': self.state = dots return res
def loss_fbo(inputs, weights, state, slots, opt_params, rng, step):
  """FBO of the final loss layer."""
  # We need a loss layer pure_fn but only for inputs and weights.
  def loss_pure_fn_without_state_and_rng(x, w):
    return loss_layer.pure_fn(x, w, state, rng)

  # Calculate the vector-Jacobian product of the reduced loss pure fn.
  loss, vjp_fn, new_state = fastmath.vjp(
      loss_pure_fn_without_state_and_rng, inputs, weights, has_aux=True)

  # The vjp function returns gradients with respect to inputs and weights.
  # Since loss is scalar and there are no other layers, run it at 1.0.
  grads_inputs, grads_weights = vjp_fn(jnp.ones((), dtype=loss.dtype))

  # In multi-device setting, average gradients from multiple devices.
  if self._n_devices > 1:
    grads_weights = _average_multidevice_gradients(grads_weights)

  # Run the loss optimizer, which is the last one since it's the last layer.
  new_weights, new_slots, stats = self._optimizers[-1].tree_update(
      step, grads_weights, weights, slots, opt_params)
  stats['loss'] = loss
  return new_weights, new_state, new_slots, grads_inputs, stats
def _causal_mask(length):
  # Not all backends define jnp.tril. However, using np.tril is inefficient
  # in that it creates a large global constant. TODO(kitaev): try to find an
  # alternative that works across all backends.
  if fastmath.is_backend(fastmath.Backend.JAX):
    return jnp.tril(jnp.ones((1, length, length), dtype=np.bool_), k=0)
  else:
    return np.tril(np.ones((1, length, length), dtype=np.bool_), k=0)
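# For example (assuming the JAX backend), _causal_mask(4)[0] is the boolean
# lower-triangular matrix below, so query i attends only to keys 0..i:
#
#   [[ True False False False]
#    [ True  True False False]
#    [ True  True  True False]
#    [ True  True  True  True]]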
def update_mask(mask, x_times_one_minus_f):  # pylint: disable=invalid-name
  initial = jnp.ones(x_times_one_minus_f.shape[:2], dtype=jnp.float32)
  if initial.shape[1] > 1:
    updated_mask = fastmath.dynamic_update_slice_in_dim(
        initial != 0, mask != 0, 1, axis=1)
  else:
    updated_mask = initial
  return updated_mask, x_times_one_minus_f
def test_pure_lsh_wrapper_non_causal_masked(self, num_weights):
  with fastmath.use_backend(fastmath.Backend.JAX):
    n_heads = 5
    batch, seqlen, d_head = 3, 32, 8
    num_weights = 2
    n_hashes = 2
    d_model = n_heads * d_head

    layer = efficient_attention.PureLSHSelfAttentionWrapper(
        n_heads=n_heads,
        d_qk=d_head,
        d_v=d_head,
        causal=False,
        masked=True,
        chunk_len=8,
        n_chunks_before=1,
        n_chunks_after=0,
        n_hashes=n_hashes,
        n_buckets=4,
        bias=False,
        pure_lsh_implementation=efficient_attention.PureLSHSelfAttention,
        mode='train',
        num_weights=num_weights)

    rng = jax.random.PRNGKey(0)
    rng, x_rng = jax.random.split(rng)

    input_shape = (batch, seqlen, d_model)
    x = jax.random.uniform(x_rng, input_shape, dtype=jnp.float32)
    mask = jnp.ones((batch, seqlen), dtype=jnp.int32)

    inp = (x, mask)
    w, s = layer.init(shapes.signature(inp))
    o = layer(inp)

    # Get the actual weights.
    weights = fastmath.tree_leaves(w)
    # Assert the number of weights is as expected; the extra 1 is for output.
    self.assertLen(weights, num_weights + 1)
    # Assert each weight is of the expected shape.
    for i in range(num_weights + 1):
      self.assertEqual(weights[i].shape, (d_model, d_model))

    # Test that the output's shape matches x's shape.
    self.assertEqual(x.shape, o.shape)

    # Assert the state has the expected shape.
    state = fastmath.tree_leaves(s)
    self.assertLen(state, 2)
    # buckets
    self.assertEqual(state[0].shape, (batch * n_heads, n_hashes * seqlen))
    # rngs
    self.assertEqual(state[1].shape, (batch * n_heads, 2))
def init_weights_and_state(self, input_signature): """Helper to initialize batch norm weights and state.""" axis = self._axis axis = (axis,) if jnp.isscalar(axis) else axis input_shape = input_signature.shape shape = tuple(d for i, d in enumerate(input_shape) if i not in axis) # TODO(jonni): Should beta and gamma match the dtype in the input signature? beta = jnp.zeros(shape, dtype='float32') if self._center else () gamma = jnp.ones(shape, dtype='float32') if self._scale else () def get_stats_axis(i, d): if i in axis: return 1 else: return d stats_shape = tuple(get_stats_axis(i, d) for i, d in enumerate(input_shape)) running_mean = jnp.zeros(stats_shape, dtype=jnp.float32) running_var = jnp.ones(stats_shape, dtype=jnp.float32) n_batches = jnp.zeros((), dtype=jnp.int64) self.weights = (beta, gamma) self.state = (running_mean, running_var, n_batches)
def _funnel_mask(batch_size, keys_len, queries_len, funnel_factor,
                 is_upsampling):
  """Creates a funnel mask.

  Based on the keys/queries lengths, this function creates a lower-triangular
  mask that prevents tokens from attending to positions that follow them. If
  funnel_factor is not equal to 1 (due to funnel upsampling or downsampling),
  the mask is adjusted for funnel attention by repeating each element
  funnel_factor times. This is because after a funnel layer one token attends
  to funnel_factor different tokens during downsampling; during upsampling, on
  the other hand, funnel_factor tokens attend to a single token from before
  upsampling.

  Args:
    batch_size: batch size.
    keys_len: keys length.
    queries_len: queries length.
    funnel_factor: funnel factor.
    is_upsampling: True if the mask is for an upsampling layer.

  Returns:
    Funnel mask.
  """
  if funnel_factor != 1:
    if not is_upsampling:
      mask = jnp.tril(jnp.ones((queries_len, queries_len), dtype=jnp.bool_))
      mask = jnp.repeat(mask, funnel_factor, axis=-1)
    else:
      mask = jnp.tril(jnp.ones((keys_len, keys_len), dtype=jnp.bool_))
      mask = jnp.repeat(mask, funnel_factor, axis=-2)
  else:
    mask = jnp.tril(jnp.ones((queries_len, queries_len), dtype=jnp.bool_))

  return jnp.repeat(mask[None, None, :, :], batch_size, axis=0)
def dot_product_self_attention(q, k, v): """ Masked dot product self attention. Args: q (jax.interpreters.xla.DeviceArray): queries. k (jax.interpreters.xla.DeviceArray): keys. v (jax.interpreters.xla.DeviceArray): values. Returns: jax.interpreters.xla.DeviceArray: masked dot product self attention tensor. """ mask_size = q.shape[-2] # Creates a matrix with ones below the diagonal and 0s above. It should have shape (1, mask_size, mask_size) mask = jnp.tril(jnp.ones((1, mask_size, mask_size), dtype=jnp.bool_), k=0) return DotProductAttention(q, k, v, mask)
def init_weights_and_state(self, input_signature):
  # Usually (B, W, H, C)
  shape = input_signature.shape
  num_channels = shape[-1]

  gamma = jnp.ones((num_channels,), dtype=jnp.float32)
  beta = jnp.zeros((num_channels,), dtype=jnp.float32)

  epsilon_l = base.EMPTY_WEIGHTS
  if self._learn_epsilon:
    epsilon_l = (self._init_learnt_epsilon,)

  self.weights = gamma, beta, epsilon_l
def forward(self, inputs):
  inputs_len = inputs.shape[1]

  if self._mode == 'predict':
    # We cannot generate more than one token at a time because it would
    # contradict the autoregressive property.
    assert inputs_len == 1
    current_token, sequence_length = calc_predict_next_token_index(
        self.state, self._total_kv_pooling, self._max_len, self._chunk_len,
        self._chunk_offset)
    mask = jnp.arange(sequence_length) <= current_token
    mask = jnp.reshape(mask, (1, sequence_length))
    self.state += self._n_raw_tokens_generated
    return mask

  if self._chunk_len is not None:
    return jnp.tril(
        jnp.ones((self._chunk_len, self._chunk_len), dtype=jnp.bool_))

  return jnp.tril(jnp.ones((inputs_len, inputs_len), dtype=jnp.bool_))
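# Example (assumed configuration, not taken from the layer above): with
# self._chunk_len = 3, the non-predict branch returns a (3, 3)
# lower-triangular boolean mask, so attention is causal within each chunk
# independently of the full input length.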
def test_autoregressive_sample_transformer(self):
  model = models.Transformer(
      10,
      d_model=32,
      d_ff=64,
      n_encoder_layers=1,
      n_decoder_layers=1,
      n_heads=2,
      mode='predict')
  inputs = jnp.ones((1, 3), dtype=jnp.int32)
  model.init((shapes.signature(inputs),
              shapes.ShapeDtype((1, 1), dtype=jnp.int32)))
  s = trainer_lib.autoregressive_sample(
      model, inputs=inputs, eos_id=-1, max_length=10)
  self.assertEqual(s.shape[0], 1)
  self.assertEqual(s.shape[1], 10)
def dot_product_self_attention(q, k, v): """ Masked dot product self attention. Args: q (jax.interpreters.xla.DeviceArray): queries. k (jax.interpreters.xla.DeviceArray): keys. v (jax.interpreters.xla.DeviceArray): values. Returns: jax.interpreters.xla.DeviceArray: masked dot product self attention tensor. """ # for causal attention: (Q. Kt) + M # mask size should be Lk x Lq mask_size = q.shape[-2] # Creates a matrix with ones below the diagonal and 0s above. It should have shape (1, mask_size, mask_size) # Notice that 1's and 0's get casted to True/False by setting dtype to jnp.bool_ # Use jnp.tril() - Lower triangle of an array and jnp.ones() mask = jnp.tril(jnp.ones((1, mask_size, mask_size), dtype=jnp.bool_), k=0) return DotProductAttention(q, k, v, mask)
def init_weights_and_state(self, input_signature):
  features = input_signature.shape[-1]
  scale = jnp.ones(features, dtype=input_signature.dtype)
  bias = jnp.zeros(features, dtype=input_signature.dtype)
  self.weights = scale, bias
def dot_product_self_attention(q, k, v):
  # Build a causal (lower-triangular) boolean mask over the query positions
  # and apply masked dot-product attention.
  mask_size = q.shape[-2]
  mask = jnp.tril(jnp.ones((1, mask_size, mask_size), dtype=jnp.bool_), k=0)
  return DotProductAttention(q, k, v, mask)
def _fast_inference_update_state(inputs, state, mask_for_predict=None):
  """Updates state of a causal attention layer for fast inference.

  The layer state stores arrays with cached values of keys and values, as
  well as an index. To make shapes static, keys and values in the state are
  long, and the index indicates where the new keys and values from inputs
  need to be appended.

  During the update, we append new_keys and new_values to keys and values at
  the position given by the index, and we increment the index by the length
  of the new keys. We also create a mask that is 1 at the appropriate
  positions (a causal mask).

  Args:
    inputs: a triple (new_queries, new_keys, new_values)
    state: layer state with (keys, values, index)
    mask_for_predict: mask used for predict mode. This is used only in
      Terraformer.

  Returns:
    Updated state and the mask to be used.
  """
  # Fast inference: run step-by-step, storing the sequence
  # of keys and values calculated so far in state.
  (_, new_k, new_v) = inputs
  if mask_for_predict is not None:
    (state_mask_for_predict, ks, vs, idx) = state
  else:
    (ks, vs, idx) = state
  length = new_k.shape[1]
  # TODO(lukaszkaiser): benchmark speed and decide if using a separate code
  # path with index_update when length == 1 is worth it.
  # Keys and values are of shape [batch_size, length, d_kv].
  ks = fastmath.dynamic_update_slice_in_dim(ks, new_k, idx, axis=1)
  vs = fastmath.dynamic_update_slice_in_dim(vs, new_v, idx, axis=1)
  k_length = ks.shape[1]

  # Mask is of shape [1, q_length, k_length].
  # The mask should be true for every pair (query_token, key_token) such that
  # the index of query_token is greater than or equal to the index of
  # key_token.
  mask = (jnp.reshape(jnp.arange(k_length), (1, 1, k_length)) <=
          jnp.reshape(jnp.arange(length) + idx, (1, length, 1)))
  if mask_for_predict is None:
    return (ks, vs, idx + length), mask
  else:
    state_mask_for_predict = fastmath.dynamic_update_slice_in_dim(
        state_mask_for_predict != 0, mask_for_predict.reshape((-1)) != 0, 0,
        axis=0)

    state_mask_for_predict = fastmath.dynamic_update_slice_in_dim(
        state_mask_for_predict != 0, jnp.ones((1,)) != 0,
        jnp.sum(mask_for_predict, dtype=jnp.int32), axis=0)

    state_mask_for_predict = fastmath.dynamic_update_slice_in_dim(
        state_mask_for_predict != 0, jnp.ones((1,)) != 0, idx, axis=0)

    placeholder = jnp.reshape(state_mask_for_predict != 0,
                              (1, 1, mask.shape[2]))
    mask = mask * placeholder

    return (state_mask_for_predict, ks, vs, idx + length), mask
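# A minimal sketch of the mask construction above: with a key cache of length
# k_length and `length` new queries starting at position idx, query j may
# attend to key i exactly when i <= idx + j.
import jax.numpy as jnp

k_length, length, idx = 6, 2, 3
mask = (jnp.reshape(jnp.arange(k_length), (1, 1, k_length)) <=
        jnp.reshape(jnp.arange(length) + idx, (1, length, 1)))
print(mask[0].astype(jnp.int32))
# [[1 1 1 1 0 0]
#  [1 1 1 1 1 0]]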
def init_weights_and_state(self, input_signature):
  self.weights = jnp.zeros((2, 3))
  self.state = jnp.ones(input_signature.shape)
def bidirectional_denominator(query_prime, key_prime):
  # Sum the (feature-mapped) keys over the length dimension, then contract
  # with the (feature-mapped) queries to get per-position normalizers.
  all_ones = jnp.ones([query_prime.shape[0]])
  ks_sum = jnp.einsum('lbm,l->bm', key_prime, all_ones)
  return jnp.einsum('lbm,bm->lb', query_prime, ks_sum)
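# A quick check of the einsum above (shapes assumed to be
# [length, batch, features]): summing key_prime over the length axis and
# contracting with query_prime gives the same per-position denominators.
import jax
import jax.numpy as jnp

q_rng, k_rng = jax.random.split(jax.random.PRNGKey(0))
query_prime = jax.random.uniform(q_rng, (5, 2, 4))
key_prime = jax.random.uniform(k_rng, (5, 2, 4))

denom = bidirectional_denominator(query_prime, key_prime)
expected = jnp.einsum('lbm,bm->lb', query_prime, jnp.sum(key_prime, axis=0))
assert jnp.allclose(denom, expected)   # both have shape (5, 2)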