def PositionsVectors(queries, keys):
    """Return relative position offsets for a queries/keys pair.

    `separate_cls` and `total_kv_pooling` are not defined in this block;
    presumably they come from the enclosing scope -- verify against caller.

    Args:
      queries: array whose axis 1 is the queries length.
      keys: array whose axis 1 is the keys length.

    Returns:
      A 1-D float array of offsets, symmetric around zero before the
      optional cls shift.
    """
    n_keys = keys.shape[1]
    n_queries = queries.shape[1]
    shapes_differ = queries.shape != keys.shape
    pooling_ratio = n_keys / n_queries

    if shapes_differ and pooling_ratio < 1:
        # Upsampling. Standard upsampling must not be combined with
        # separate_cls, because the cls token is used for classification.
        assert not separate_cls
        scaled_keys_len = total_kv_pooling * n_keys
        assert scaled_keys_len % n_queries == 0
        span, step = n_queries, scaled_keys_len // n_queries
    else:
        span, step = n_keys, total_kv_pooling
    positions = jnp.arange(-span + 1, span, 1.0) * step

    if shapes_differ and separate_cls:
        # For pool_size 2 without separating cls we get
        #   [0][1][2][3][4][5][6][7] -> [01][23][45][67]
        # and with separating cls
        #   [0][1][2][3][4][5][6][7] -> [0][12][34][56]
        # The first group then always holds a single token instead of
        # pool_size tokens, so shift the positions to keep the later
        # attention shift aligned.
        positions = positions + (pooling_ratio - 1) * total_kv_pooling
    return positions
def _sincos(self, start, length, d_feature):
    """Build a sin-cos table of shape [1, length, d_feature].

    Args:
      start: offset added to every position index.
      length: number of consecutive positions to encode.
      d_feature: feature depth; sines fill the first half of the last axis
          and cosines the second half.

    Returns:
      Array with a leading singleton batch axis.
    """
    inv_timescales = jnp.exp(
        jnp.arange(0, d_feature, 2) * -(jnp.log(10000.0) / d_feature))
    angles = (jnp.arange(0, length)[:, None] + start) * inv_timescales
    table = jnp.concatenate([jnp.sin(angles), jnp.cos(angles)], axis=1)
    return table[None, :, :]  # [1, length, d_feature]
def PositionsVectors(self, n_tokens):
    """Return position offsets measured from the current (or last) token.

    Outside predict mode the result is an integer range ending at 0. In
    predict mode the range is float, shifted by the current token index,
    and `self.state` is advanced by `self._n_raw_tokens_generated`.

    Args:
      n_tokens: sequence length, used when `self._chunk_len` is None.

    Returns:
      A 1-D array of relative positions.
    """
    if self._mode != 'predict':
        sequence_length = (
            n_tokens if self._chunk_len is None else self._chunk_len)
        # Shift so the last position is 0, compatible with predict mode.
        return jnp.arange(sequence_length) - (sequence_length - 1)

    current_token, sequence_length = calc_predict_next_token_index(
        self.state, self._total_kv_pooling, self._max_len,
        self._chunk_len, self._chunk_offset)
    positions = jnp.arange(0, sequence_length, 1.0) - current_token
    self.state = self.state + self._n_raw_tokens_generated
    return positions
def one_hot(x, n_categories, dtype=jnp.float32):  # pylint: disable=invalid-name
    """Makes a one-hot array (n+1 dims) from an int-categorical array (n dims).

    The former `jax.lax.tie_in` broadcasting workaround is gone: `tie_in`
    has been a no-op since JAX switched to omnistaging and is no longer
    available in current JAX releases, so calling it only risks an
    AttributeError.

    Args:
      x: integer array of category ids.
      n_categories: size of the new trailing one-hot axis.
      dtype: dtype of the returned array.

    Returns:
      Array of shape `x.shape + (n_categories,)` holding 1 where the
      category id matches the index on the last axis and 0 elsewhere.
    """
    indices_less_than_n = jnp.arange(n_categories)
    return jnp.array(x[..., jnp.newaxis] == indices_less_than_n, dtype)
def forward(self, x):
    """Add mixed-radix positional embeddings to `x`.

    Each position is decomposed into `self._n_digits` digits for every base
    in `self._bases`; each digit value scales the matching weight and the
    per-digit pieces are concatenated on the last axis. In train mode the
    positions get a random per-example start offset and each base's
    embedding is randomly zeroed per example.

    Args:
      x: array whose axis 0 is batch and axis 1 is sequence length.

    Returns:
      `x` plus the summed embeddings of all bases.
    """
    rng = self.rng
    n_batch, seq_len = x.shape[0], x.shape[1]
    max_pos = min(self._bases) ** self._n_digits
    start_rng, offset_rng, dropout_rng = fastmath.random.split(rng, 3)
    assert seq_len < max_pos, 'length (%d) >= max_pos (%d)' % (seq_len, max_pos)

    training = self._mode == 'train'
    positions = jnp.arange(0, seq_len)[None, :]
    if training:
        # In 1% of training cases still start from 0 to be exactly as in eval.
        keep_offset = jnp.minimum(
            1,
            fastmath.random.randint(
                start_rng, (n_batch,), 0, self._start_from_zero_one_in))
        offsets = fastmath.random.randint(
            offset_rng, (n_batch,), 0, max_pos - seq_len)
        positions = positions + (offsets * keep_offset)[:, None]

    per_base = []
    for base_idx, base in enumerate(self._bases):
        remaining = positions
        digit_embs = []
        for digit in range(self._n_digits):
            digit_vals = jnp.mod(remaining, base)
            remaining = remaining // base
            weight = self.weights[base_idx][digit]
            digit_embs.append(
                digit_vals.astype(jnp.float32)[:, :, None] * weight)
        base_emb = jnp.concatenate(digit_embs, axis=-1)
        if training:
            # The same rng key is used for every base, as in the original.
            keep = fastmath.random.randint(
                dropout_rng, (n_batch,), 0, self._base_dropout_one_in)
            keep = jnp.minimum(1, keep).astype(jnp.float32)
            base_emb = base_emb * keep[:, None, None]
        per_base.append(base_emb)

    total = sum(per_base) + jnp.zeros_like(x)
    return x + total
def Sinusoidal_Embeddings(positions):
    """Return sin/cos embeddings for a 1-D vector of positions.

    `d_feature` is not defined in this block; presumably it comes from the
    enclosing scope -- verify against caller.

    Args:
      positions: a one-dimensional array of positions.

    Returns:
      2-D array with sines in the first half of axis 1 and cosines in the
      second half.
    """
    exponents = jnp.arange(0.0, d_feature, 2.0) / d_feature
    inv_freq = 1 / (10000**exponents)
    angles = jnp.einsum('i,j->ij', positions, inv_freq)
    return jnp.concatenate([jnp.sin(angles), jnp.cos(angles)], axis=1)
def forward(self, inputs):
    """Replace the trailing feature channels of `inputs` with embeddings.

    In predict mode the embedding is taken at the step stored in
    `self.state`, which is then incremented; otherwise embeddings are
    computed for every step along axis -2 of `inputs`.

    Args:
      inputs: array whose last axis is the feature depth.

    Returns:
      Array with the same last-axis depth as `inputs`.
    """
    next_state = self.state
    depth = inputs.shape[-1]

    if self._mode == 'predict':
        emb = self._get_embeddings(t=next_state)[:, jnp.newaxis, :]
        next_state = next_state + 1
    else:
        steps = jnp.arange(inputs.shape[-2], dtype=jnp.int32)
        # Leave batch axis as 1 for broadcasting:
        emb = self._get_embeddings(t=steps)[jnp.newaxis, :, :]

    # NOTE(review): the broadcast width is a literal 3 while the slice below
    # uses self.num_features -- presumably equal; confirm.
    emb = jnp.broadcast_to(emb, inputs.shape[:-1] + (3,))

    # Replace the last num_features channels of input.
    merged = jnp.concatenate([inputs[..., :-self.num_features], emb], -1)
    if merged.shape[-1] > depth:
        logging.warning('dropping feature(s): %d down to %d',
                        merged.shape[-1], depth)
        merged = merged[..., -depth:]
    assert merged.shape[-1] == depth, merged.shape

    self.state = next_state
    return merged
def PositionsVectors(queries, keys):
    """Return relative position offsets for a queries/keys pair.

    `separate_cls`, `total_kv_pooling` and `calc_funnel_ratio` are not
    defined in this block; presumably they come from the enclosing scope or
    module -- verify against caller.

    Args:
      queries: array whose axis -2 is the queries length.
      keys: array whose axis -2 is the keys length.

    Returns:
      A 1-D array of offsets.
    """
    assert not separate_cls
    keys_len, queries_len = keys.shape[-2], queries.shape[-2]
    funnel_factor, is_upsampling = calc_funnel_ratio(keys_len, queries_len)

    if funnel_factor == 1:
        # No funnel: integer offsets ending at 0, scaled by the pooling.
        return (jnp.arange(keys_len) - (keys_len - 1)) * total_kv_pooling
    if is_upsampling:
        return jnp.arange(-queries_len + 1, queries_len, 1.0)
    return jnp.arange(-keys_len + 1, keys_len, 1.0) * total_kv_pooling
def rotate(x):
    """Apply a position-dependent rotation to a [batch, length, depth] array.

    Computes `x * cos + rotate_half(x) * sin`, where the angle for position
    `p` and frequency `j` is `p * exp(-2j * log(10000) / depth)` and
    `rotate_half` maps the halves `(x1, x2)` of the last axis to `(-x2, x1)`.

    Args:
      x: rank-3 array.

    Returns:
      Array with the same shape as `x`.
    """
    _, seq_len, depth = x.shape
    timescales = jnp.exp(jnp.arange(0, depth, 2) * -(jnp.log(10000.0) / depth))
    angles = jnp.einsum('i,j->ij', jnp.arange(seq_len), timescales)
    # The same angles serve both halves of the feature axis.
    angles = jnp.concatenate((angles, angles), axis=-1)

    half = x.shape[-1] // 2
    swapped = jnp.concatenate((-x[..., half:], x[..., :half]), axis=-1)

    cos_part = jnp.einsum('bld,ld->bld', x, jnp.cos(angles))
    sin_part = jnp.einsum('bld,ld->bld', swapped, jnp.sin(angles))
    return cos_part + sin_part
def _funnel_mask(self, batch_size, keys_len, queries_len, funnel_factor,
                 is_upsampling):
    """Creates a funnel mask.

    Builds a lower-triangular mask that prevents tokens from attending to
    positions following them. When `funnel_factor` is not 1 each mask
    element is repeated `funnel_factor` times: along the keys axis for
    downsampling (one token attends to `funnel_factor` tokens) and along the
    queries axis for upsampling (`funnel_factor` tokens attend to a single
    token).

    In predict mode the mask has shape [batch_size, 1, 1, max_len] and
    `self.state` is advanced by `self._n_raw_tokens_generated`.

    Args:
      batch_size: batch size.
      keys_len: keys length.
      queries_len: queries length.
      funnel_factor: funnel factor.
      is_upsampling: upsampling if set to True.

    Returns:
      Funnel mask.
    """
    if self._mode == 'predict':
        # We cannot generate more than one token because it contradicts
        # all autoregressive properties
        assert queries_len == 1
        visible = jnp.arange(self._max_len) <= (
            self.state // self._total_kv_pooling)
        visible = jnp.reshape(visible, (1, 1, 1, self._max_len))
        visible = jnp.repeat(visible, batch_size, axis=0)
        self.state += self._n_raw_tokens_generated
        return visible

    funnelled = funnel_factor != 1
    if funnelled and is_upsampling:
        side, repeat_axis = keys_len, -2
    else:
        side, repeat_axis = queries_len, -1
    causal = jnp.tril(jnp.ones((side, side), dtype=jnp.bool_))
    if funnelled:
        causal = jnp.repeat(causal, funnel_factor, axis=repeat_axis)
    return jnp.repeat(causal[None, None, :, :], batch_size, axis=0)
def _fast_inference_update_state(inputs, state):
    """Updates state of a causal attention layer for fast inference.

    The layer state caches keys and values together with a write index. The
    cached arrays are long so that shapes stay static; the index marks where
    the new keys and values are written. The index is then advanced by the
    number of new positions and a causal mask is built for them.

    Args:
      inputs: a triple (new_queries, new_keys, new_values)
      state: layer state with (keys, values, index)

    Returns:
      Updated state and mask to be used.
    """
    _, new_k, new_v = inputs
    ks, vs, idx = state
    n_new = new_k.shape[1]

    # TODO(lukaszkaiser): benchmark speed and decide if using a separate code
    # path with index_update when length == 1 is worth it.
    # Keys and values are of shape [batch_size, length, d_kv].
    ks = fastmath.dynamic_update_slice_in_dim(ks, new_k, idx, axis=1)
    vs = fastmath.dynamic_update_slice_in_dim(vs, new_v, idx, axis=1)

    # Mask is of shape [1, q_length, k_length]; it is true for every
    # (query_token, key_token) pair where the query index is greater than or
    # equal to the key index.
    k_length = ks.shape[1]
    key_pos = jnp.reshape(jnp.arange(k_length), (1, 1, k_length))
    query_pos = jnp.reshape(jnp.arange(n_new) + idx, (1, n_new, 1))
    mask = key_pos <= query_pos
    return (ks, vs, idx + n_new), mask
def forward(self, x):
    """Add mixed-radix positional embeddings with a learned start marker.

    Positions are decomposed into `self._n_digits` digits for every base in
    `self._bases`; each digit value scales the matching base weight. In train
    mode the starts are randomized and each base's embedding is randomly
    zeroed per example. In predict mode positions are offset by `self.state`,
    which is advanced by the input length.

    Args:
      x: array whose axis 0 is batch and axis 1 is sequence length.

    Returns:
      `x` plus the summed embeddings, with `start_vec` added at position 0
      (in predict mode only when `self.state` equals 0).
    """
    rng = self.rng
    base_weights, start_vec = self.weights
    n_batch, seq_len = x.shape[0], x.shape[1]
    max_pos = min(self._bases) ** self._n_digits
    start_rng, offset_rng, dropout_rng = fastmath.random.split(rng, 3)
    assert seq_len < max_pos, 'length (%d) >= max_pos (%d)' % (seq_len, max_pos)

    training = self._mode == 'train'
    predicting = self._mode == 'predict'
    positions = jnp.arange(0, seq_len)[None, :]

    # In training we'll randomize starts for better generalization.
    # The trainable start_vec compensates and gives the model a way to learn
    # what the starting position in a sequence is.
    if training:
        # In 1% of training cases still start from 0 to be exactly as in eval.
        keep_offset = jnp.minimum(
            1,
            fastmath.random.randint(
                start_rng, (n_batch,), 0, self._start_from_zero_one_in))
        offsets = fastmath.random.randint(
            offset_rng, (n_batch,), 0, max_pos - seq_len)
        positions = positions + (offsets * keep_offset)[:, None]
    if predicting:
        positions = positions + self.state

    per_base = []
    for base_idx, base in enumerate(self._bases):
        remaining = positions
        digit_embs = []
        for digit in range(self._n_digits):
            digit_vals = jnp.mod(remaining, base)
            remaining = remaining // base
            weight = base_weights[base_idx][digit]
            digit_embs.append(
                digit_vals.astype(jnp.float32)[:, :, None] * weight)
        base_emb = jnp.concatenate(digit_embs, axis=-1)
        if training:
            # The same rng key is used for every base, as in the original.
            keep = fastmath.random.randint(
                dropout_rng, (n_batch,), 0, self._base_dropout_one_in)
            keep = jnp.minimum(1, keep).astype(jnp.float32)
            base_emb = base_emb * keep[:, None, None]
        per_base.append(base_emb)
    total = sum(per_base)  # Sum embeddings from all bases.

    # Add start_vec to the first position only to mark it as starting.
    first = total[:, 0, :][:, None, :]
    marked_first = first + start_vec
    if predicting:
        marked_first = jnp.where(
            jnp.equal(self.state, 0), marked_first, first)
        self.state += seq_len  # Add input length to state.
    total = jnp.concatenate([marked_first, total[:, 1:, :]], axis=1)
    return x + total
def Sinusoidal_Embeddings(positions, d_feature):
    """Sinusoidal Embeddings.

    Turns a 1-D vector of absolute positions into sinusoidal embeddings as
    defined in the paper Attention Is All You Need (2017). Sines occupy the
    first half of the feature axis and cosines the second half.

    Args:
      positions: a one-dimensional array of positions.
      d_feature: the number of sin-cos features.

    Returns:
      Positional embeddings with one row per position.
    """
    exponents = jnp.arange(0.0, d_feature, 2.0) / d_feature
    inv_freq = 1 / (10000**exponents)
    angles = jnp.einsum('i,j->ij', positions, inv_freq)
    return jnp.concatenate([jnp.sin(angles), jnp.cos(angles)], axis=1)
def _fast_inference_update_state(inputs, state):
    """Updates state of a causal attention layer for fast inference.

    Writes the single new key/value of every batch element into the cached
    keys/values at that element's own sequence index, marks that index in
    the mask, and advances all sequence indices by one.

    Args:
      inputs: a triple (new_queries, new_keys, new_values); every element
          must have length 1 on axis 1.
      state: layer state (keys, values, mask, seq_indices).

    Returns:
      The updated state (keys, values, mask, seq_indices + 1).

    Raises:
      ValueError: if the active backend is not JAX, or if any input has a
          sequence length other than 1.
    """
    backend = fastmath.backend_name()
    if backend != 'jax':
        # The original message called the misspelled `backend_nameO()`, which
        # raised AttributeError instead of this ValueError.
        raise ValueError('JAX backend is required in predict mode, but found '
                         f'backend ({backend}).')
    for x in inputs:
        if x.shape[1] != 1:
            raise ValueError(
                'In predict mode, input sequence must have length 1, '
                f'instead has length {x.shape[1]}.')

    # Fast inference: run with only 1 query in each step, storing the sequence
    # of keys and values calculated so far in state.
    (_, new_k, new_v) = inputs
    (ks, vs, mask, seq_indices) = state
    batch_indices = jnp.arange(ks.shape[0])
    # NOTE(review): jax.ops.index_update / jax.ops.index were removed from
    # newer JAX releases in favour of `array.at[...].set(...)`; migrate once
    # the minimum supported JAX version allows it.
    ks = jax.ops.index_update(
        ks, jax.ops.index[batch_indices, seq_indices, :], new_k[:, 0, :])
    vs = jax.ops.index_update(
        vs, jax.ops.index[batch_indices, seq_indices, :], new_v[:, 0, :])
    mask = jax.ops.index_update(
        mask, jax.ops.index[batch_indices, :, seq_indices], 1)
    return (ks, vs, mask, seq_indices + 1)
def forward(self, inputs):
    """Return a causal attention mask for `inputs`.

    Outside predict mode this is a square lower-triangular boolean matrix
    whose side is `self._chunk_len` when set, otherwise the input length.
    In predict mode it is a [1, sequence_length] boolean row that is true up
    to the current token, and `self.state` is advanced by
    `self._n_raw_tokens_generated`.

    Args:
      inputs: array whose axis 1 is the sequence length.

    Returns:
      Boolean mask array.
    """
    inputs_len = inputs.shape[1]

    if self._mode != 'predict':
        side = inputs_len if self._chunk_len is None else self._chunk_len
        return jnp.tril(jnp.ones((side, side), dtype=jnp.bool_))

    # We cannot generate more than one token because it contradicts
    # all autoregressive properties
    assert inputs_len == 1
    current_token, sequence_length = calc_predict_next_token_index(
        self.state, self._total_kv_pooling, self._max_len,
        self._chunk_len, self._chunk_offset)
    visible = jnp.arange(sequence_length) <= current_token
    visible = jnp.reshape(visible, (1, sequence_length))
    self.state += self._n_raw_tokens_generated
    return visible
def threefry_2x32_prange(key, lo: int = 0, hi: int = 2):
    """Splits a key into a stream of random keys.

    Uses the little-endian counter mode: each counter is a pair whose low
    word runs over `[lo, hi)` and whose high word is zero.

    Args:
      key: uint32[2] the key to split
      lo: the range to start extracting from
      hi: the range to stop extracting from

    Returns:
      keys: uint32[hi - lo, 2] the split keys

    Raises:
      ValueError: if `key` is not a uint32 array of shape (2,).
      NotImplementedError: if `hi` is not below 2**32.
    """
    if key.shape != (2,) or key.dtype != jnp.uint32:
        raise ValueError('key must be uint32[2]')
    if not (hi < 2**32):
        # You shouldn't really be using more than half the key size anyways.
        raise NotImplementedError('only 32-bit sizes are supported')

    # Create a 64-bit counter as (low word, high word) pairs.
    counter_lo = jnp.arange(lo, hi, dtype=jnp.uint32)
    counter = jnp.stack([counter_lo, jnp.zeros_like(counter_lo)], axis=-1)
    return threefry_2x32_prf(key, counter)
def one_hot(x, n_categories, dtype=jnp.float32):  # pylint: disable=invalid-name
    """Makes a one-hot array (n+1 dims) from an int-categorical array (n dims).

    Args:
      x: integer array of category ids.
      n_categories: size of the new trailing one-hot axis.
      dtype: dtype of the returned array.

    Returns:
      Array of shape `x.shape + (n_categories,)` holding 1 where the
      category id matches the index on the last axis and 0 elsewhere.
    """
    categories = jnp.arange(n_categories)
    # Comparing against a trailing singleton axis broadcasts to (..., n).
    hits = x[..., None] == categories
    return jnp.array(hits, dtype)
def top_k(x, k):
    """Select the top k slices from the last dimension.

    Values are sorted along the last axis, so the returned slices are in
    ascending order (largest value last).

    Args:
      x: array to select from.
      k: number of entries to keep along the last axis. `k <= 0` yields
          empty results; `k` above the axis size returns the whole axis.

    Returns:
      A pair `(values, indices)`, both shaped like `x` with the last axis
      truncated to at most `k` entries; `indices` are positions along the
      last axis of `x`.
    """
    n = x.shape[-1]
    bcast_idxs = jnp.broadcast_to(np.arange(n), x.shape)
    sorted_vals, sorted_idxs = lax.sort_key_val(x, bcast_idxs)
    # Slice from an explicit start index: `[..., -k:]` with k == 0 is
    # `[..., 0:]` and would return every element instead of none.
    start = max(n - k, 0)
    # TODO(levskaya): use lax.slice here instead to benefit from XLA optimization
    return sorted_vals[..., start:], sorted_idxs[..., start:]
def forward(self, x):
    """Executes this layer as part of a forward pass through the model.

    A low-rank controller (`m1`, `m2`, `mb`) produces per-block logits of
    shape [-1, d1, d2]; one element per block is picked with a
    Gumbel-perturbed argmax, and only the picked feed-forward units
    contribute to the result. The three modes differ in how the pick is
    applied: train uses a dense matmul with a straight-through mask, predict
    gathers the picked weights per example, and every other mode uses a
    dense matmul with a hard one-hot mask.

    Args:
      x: Tensor of same shape and dtype as the input signature used to
          initialize this layer.

    Returns:
      Tensor of same shape and dtype as the input.
    """
    m1, m2, mb, w1, w2, b2 = self.weights
    if self._mode != 'predict':
        # Outside predict mode the weights are flattened into plain 2-D
        # matrices for the dense matmuls below.
        w1 = jnp.reshape(w1.T, (-1, self._d_ff))
        w2 = jnp.reshape(w2, (self._d_ff, -1))
    x_shape = x.shape
    x = jnp.reshape(x, [-1, x_shape[-1]])  # Easier to operate on flattened x.
    # Q: should we add bias and/or put relu after the low-rank m1 dot?
    mask_logits = jnp.dot(jnp.dot(x, m1), m2) + mb
    mask_logits = jnp.reshape(mask_logits, [-1, self._d1, self._d2])
    # Softmax.
    mask_logsumexp = fastmath.logsumexp(mask_logits, axis=-1, keepdims=True)
    log_mask = mask_logits - mask_logsumexp
    mask = jnp.exp(log_mask)
    # Gumbel-softmax with straight-through discretization.
    rng1, rng2 = fastmath.random.split(self.rng, 2)
    # u is kept strictly inside (0, 1) so both logs below stay finite.
    u = fastmath.random.uniform(rng1, mask.shape, jnp.float32, 1e-6, 1.0 - 1e-6)
    g = -jnp.log(-jnp.log(u))
    # Index of the chosen element within each block.
    quant_mask = jnp.argmax(log_mask + g * self._temperature, axis=-1)
    if self._mode == 'train':
        # Tricks from Section 2.1 in https://arxiv.org/abs/1801.09797
        quant_mask = tl.one_hot(quant_mask, self._n_elements_in_block)
        quant_mask = fastmath.stop_gradient(quant_mask)
        quant_mask += mask - fastmath.stop_gradient(
            mask)  # straight-through
        # We will sometimes (quant_prob of the batches) use the soft-mask instead
        # of the quantized mask to improve training stability (see paper above).
        select = fastmath.random.uniform(rng2, (), jnp.float32, 0.0, 1.0)
        quant_mask = jnp.where(select < self._quant_prob, quant_mask, mask)
        # NOTE(review): the source arrived with its indentation collapsed;
        # this reshape is placed inside the train branch because predict
        # mode below uses quant_mask as flat indices and the final branch
        # reshapes on its own -- confirm against the original file.
        quant_mask = jnp.reshape(quant_mask, [-1, self._d_ff])
    if self._mode == 'train':
        # In training, run full matmul to get benefits from the above tricks.
        mid = jnp.dot(x, w1) * quant_mask  # [joint_batch, d_ff]
        relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid)
        res = jnp.dot(relu, w2) + b2
    elif self._mode == 'predict':
        # w1 = jnp.reshape(w1.T, (self._d1, self._d2, -1))
        # w2 = jnp.reshape(w2, (self._d1, self._d2, -1))
        # This implementation mimicks inference. It's not efficient for large
        # size of joint_batch, but at inference that will be 1 most of the time.
        # Shapes:
        # quant_mask is [joint_batch, self._d1]
        # w1 is [d_model, self._d1, self._d2]
        # NOTE(review): the indexing below (w1[idx1, idx2, :]) treats w1 as
        # [self._d1, self._d2, d_model], so the shape comment above may be
        # stale -- confirm.
        # we'll index w1 with advanced numpy indexing, first range over
        # self._d1 times the batch size, second range being quant_mask
        batch_size = quant_mask.shape[0]
        idx1 = jnp.array([jnp.arange(self._d1)] * batch_size)
        # flatten indices and select from w1
        idx1 = jnp.reshape(idx1, [-1])
        idx2 = jnp.reshape(quant_mask, [-1])
        w = w1[idx1, idx2, :]
        # now we have per-element weights with batch dim
        w = jnp.reshape(w, [batch_size, self._d1, -1])
        mid = jnp.einsum('ai,aji->aj', x, w)
        relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid)
        # w2 is [self._d1, self._d2, d_model]
        v = w2[idx1, idx2, :]
        v = jnp.reshape(v, [batch_size, self._d1, -1])
        res = jnp.einsum('ai,aij->aj', relu, v) + b2
    else:
        # Any other mode: hard one-hot mask applied to the dense matmul.
        quant_mask = tl.one_hot(quant_mask, self._n_elements_in_block)
        quant_mask = jnp.reshape(quant_mask, [-1, self._d_ff])
        mid = jnp.dot(x, w1) * quant_mask  # [joint_batch, d_ff]
        relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid)
        res = jnp.dot(relu, w2) + b2
    return jnp.reshape(res, x_shape)  # un-flatten if needed
def bit_sequence(inputs):
    """Return the bit indices needed to encode positions of `inputs`.

    The number of bits is that of the largest position index,
    `seq_length - 1`. It is computed with exact integer arithmetic
    (`int.bit_length`) rather than a ratio of floating-point logarithms,
    which is subject to rounding and evaluates `log(0)` when the sequence
    has a single element.

    Args:
      inputs: array whose axis 1 is the sequence length.

    Returns:
      `jnp.arange(0, n_bits)`; empty when the sequence length is at most 1.
    """
    seq_length = inputs.shape[1]
    # For m >= 1, floor(log2(m)) + 1 == m.bit_length(); for m == 0 it is 0.
    n_bits = max(int(seq_length) - 1, 0).bit_length()
    return jnp.arange(0, n_bits)