def NewPositionalEncoding(x, positions=None, **kwargs):
  """Implements new positional encoding."""
  del kwargs
  x_length = np.shape(x)[1]
  pos = np.array(positions)[np.newaxis, :x_length, :]
  pos += np.zeros((np.shape(x)[0], 1, 1))  # Broadcast on batch.
  res = np.concatenate([x, pos], axis=2)
  return res
def ChunkedPositionalEncoding(x, params, **unused_kwargs):
  """Implements bare positional encoding."""
  if not isinstance(x, (list, tuple)):  # non-chunked inputs
    symbol_size = np.shape(x)[1]
    return x + params[:, :symbol_size, :]
  # Chunked case: apply to all chunks selecting as much as needed.
  offset = 0
  results = []
  for chunk in x:
    symbol_size = np.shape(chunk)[1]
    results.append(chunk + params[:, offset:offset + symbol_size, :])
    offset += symbol_size
  return results
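# Illustrative sketch (not part of the original code): checks, with plain
# NumPy and a made-up parameter table, that applying the encoding
# chunk-by-chunk with a running offset matches encoding the concatenated
# sequence in one call. The helper name and shapes below are assumptions.
def _example_chunked_encoding_consistency():
  import numpy as onp  # plain NumPy, to keep the sketch self-contained
  params = onp.random.randn(1, 16, 4)           # (1, max_len, d_feature)
  chunks = [onp.random.randn(2, 3, 4),          # two chunks of one batch,
            onp.random.randn(2, 5, 4)]          # with lengths 3 and 5
  # Chunked application: each chunk uses the next slice of the table.
  offset, encoded = 0, []
  for chunk in chunks:
    length = chunk.shape[1]
    encoded.append(chunk + params[:, offset:offset + length, :])
    offset += length
  # Same result as encoding the concatenated sequence directly.
  full = onp.concatenate(chunks, axis=1)
  assert onp.allclose(onp.concatenate(encoded, axis=1),
                      full + params[:, :full.shape[1], :])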
def DotProductAttention(query, key, value, mask, dropout, mode, rng):
  """Core dot product self-attention.

  Args:
    query: array of representations
    key: array of representations
    value: array of representations
    mask: attention-mask, gates attention
    dropout: float: dropout rate
    mode: 'eval' or 'train': whether to use dropout
    rng: JAX PRNGKey: subkey for disposable use

  Returns:
    Self attention for q, k, v arrays.
  """
  depth = np.shape(query)[-1]
  dots = np.matmul(query, np.swapaxes(key, -1, -2)) / np.sqrt(depth)
  if mask is not None:
    # TODO(kitaev): workaround for https://github.com/google/jax/issues/850
    # We must ensure that both mask and the -1e9 constant have a data
    # dependency on the input. Broadcasted copies of these use a lot of
    # memory, so they should be computed at runtime (rather than being global
    # constants).
    if backend.get_name() == 'jax':
      mask = jax.lax.tie_in(dots, mask)
    dots = np.where(mask, dots, np.full_like(dots, -1e9))
  # Softmax.
  dots = np.exp(dots - backend.logsumexp(dots, axis=-1, keepdims=True))
  if dropout is not None and dropout >= 1.0:
    raise ValueError('Dropout rates must be lower than 1.')
  if dropout is not None and dropout > 0.0 and mode == 'train':
    keep = backend.random.bernoulli(rng, 1.0 - dropout, dots.shape)
    dots = np.where(keep, dots / (1.0 - dropout), np.zeros_like(dots))
  out = np.matmul(dots, value)
  return out
def DotProductAttention(query, key, value, mask, dropout, mode, rng):
  """Core dot product self-attention.

  Args:
    query: array of representations
    key: array of representations
    value: array of representations
    mask: attention-mask, gates attention
    dropout: float: dropout rate
    mode: 'eval' or 'train': whether to use dropout
    rng: JAX PRNGKey: subkey for disposable use

  Returns:
    Self attention for q, k, v arrays.
  """
  depth = np.shape(query)[-1]
  dots = np.matmul(query, np.swapaxes(key, -1, -2)) / np.sqrt(depth)
  if mask is not None:
    dots = np.where(mask, dots, -1e9)
  # Softmax.
  dots = np.exp(dots - backend.logsumexp(dots, axis=-1, keepdims=True))
  if dropout is not None and dropout >= 1.0:
    raise ValueError('Dropout rates must be lower than 1.')
  if dropout is not None and dropout > 0.0 and mode == 'train':
    keep = backend.random.bernoulli(rng, 1.0 - dropout, dots.shape)
    dots = np.where(keep, dots / (1.0 - dropout), 0)
  out = np.matmul(dots, value)
  return out
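# Illustrative sketch (not part of the original code): the same
# softmax(Q K^T / sqrt(d)) V computation written against plain NumPy, with a
# lower-triangular causal mask gating attention. Shapes and names here are
# assumptions made for the example only.
def _example_dot_product_attention():
  import numpy as onp
  batch, length, depth = 2, 5, 8
  q = onp.random.randn(batch, length, depth)
  k = onp.random.randn(batch, length, depth)
  v = onp.random.randn(batch, length, depth)
  mask = onp.tril(onp.ones((length, length), dtype=bool))   # causal mask
  dots = onp.matmul(q, onp.swapaxes(k, -1, -2)) / onp.sqrt(depth)
  dots = onp.where(mask, dots, -1e9)                        # gate attention
  dots = onp.exp(dots - dots.max(axis=-1, keepdims=True))   # stable softmax
  dots = dots / dots.sum(axis=-1, keepdims=True)
  out = onp.matmul(dots, v)                                 # weighted values
  assert out.shape == (batch, length, depth)
  return out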
def PureAttention(x, params, n_heads=1, dropout=0.0, mode='train', **kwargs):
  """Pure transformer-style multi-headed attention.

  Args:
    x: inputs (q, k, v, mask)
    params: parameters (none)
    n_heads: int: number of attention heads
    dropout: float: dropout rate
    mode: str: 'train' or 'eval'
    **kwargs: other arguments including the rng

  Returns:
    Pure Multi-headed attention result, and the mask.
  """
  del params
  rng = kwargs.get('rng', None)
  q, k, v, mask = x
  d_feature = q.shape[-1]
  assert d_feature % n_heads == 0
  d_head = d_feature // n_heads
  nbatch = np.shape(q)[0]

  # nbatch, seqlen, d_feature --> nbatch, n_heads, seqlen, d_head
  def SplitHeads(x):
    return np.transpose(
        np.reshape(x, (nbatch, -1, n_heads, d_head)), (0, 2, 1, 3))

  # nbatch, n_heads, seqlen, d_head --> nbatch, seqlen, d_feature
  def JoinHeads(x):  # pylint: disable=invalid-name
    return np.reshape(
        np.transpose(x, (0, 2, 1, 3)), (nbatch, -1, n_heads * d_head))

  # Split heads, dot-product attention, rejoin heads.
  res = JoinHeads(
      DotProductAttention(
          SplitHeads(q), SplitHeads(k), SplitHeads(v), mask,
          dropout=dropout, mode=mode, rng=rng))
  return res, mask  # Keep the mask.
def apply_fun(params, inputs, **kwargs):  # pylint: disable=missing-docstring
  del params
  rng = kwargs.get('rng', None)
  q, k, v, mask = inputs
  assert feature_depth % num_heads == 0
  head_depth = feature_depth // num_heads
  nbatch = np.shape(q)[0]

  # nbatch, seqlen, feature_depth --> nbatch, num_heads, seqlen, head_depth
  def split_heads(x):
    return np.transpose(
        np.reshape(x, (nbatch, -1, num_heads, head_depth)), (0, 2, 1, 3))

  # nbatch, num_heads, seqlen, head_depth --> nbatch, seqlen, feature_depth
  def join_heads(x):
    return np.reshape(
        np.transpose(x, (0, 2, 1, 3)), (nbatch, -1, num_heads * head_depth))

  # Split heads, dot-product attention, rejoin heads.
  return join_heads(
      dot_product_attention(
          split_heads(q), split_heads(k), split_heads(v), mask,
          dropout=dropout, mode=mode, rng=rng))
def dot_product_attention(query, key, value, mask, dropout, mode, rng):
  """Core dot product self-attention.

  Args:
    query: array of representations
    key: array of representations
    value: array of representations
    mask: attention-mask, gates attention
    dropout: float: dropout rate - keep probability
    mode: 'eval' or 'train': whether to use dropout
    rng: JAX PRNGKey: subkey for disposable use

  Returns:
    Self attention for q, k, v arrays.
  """
  depth = np.shape(query)[-1]
  dots = np.matmul(query, np.swapaxes(key, -1, -2)) / np.sqrt(depth)
  if mask is not None:
    dots = np.where(mask, dots, -1e9)
  dots = stax.softmax(dots, axis=-1)
  if dropout is not None and mode == 'train':
    keep = random.bernoulli(rng, dropout, dots.shape)
    dots = np.where(keep, dots / dropout, 0)
  out = np.matmul(dots, value)
  return out
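# Illustrative sketch (not part of the original code): unlike the variants
# above, this function treats `dropout` as a keep probability, so surviving
# attention weights are rescaled by 1/keep_prob (inverted dropout) to keep
# the expected attention mass unchanged. Plain NumPy stands in for the
# backend; the helper name and shapes are assumptions.
def _example_inverted_dropout(keep_prob=0.9):
  import numpy as onp
  weights = onp.random.rand(2, 5, 5)                # attention weights
  keep = onp.random.rand(*weights.shape) < keep_prob
  dropped = onp.where(keep, weights / keep_prob, 0.0)
  # In expectation, `dropped` has the same mean as `weights`.
  return dropped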
def SplitHeads(x, params, n_heads=1, **kwargs):
  del params, kwargs
  d_model = x.shape[-1]
  assert d_model % n_heads == 0
  d_head = d_model // n_heads
  n_batch = np.shape(x)[0]
  # n_batch, seqlen, d_model --> n_batch, n_heads, seqlen, d_head
  return np.transpose(
      np.reshape(x, (n_batch, -1, n_heads, d_head)), (0, 2, 1, 3))
def forward(self, inputs, params=(), state=(), **kwargs):
  if self._mode in ('train', 'eval'):
    x = inputs
    symbol_size = np.shape(x)[1]
    return (x + params[:, :symbol_size, :], state)
  else:
    assert self._mode == 'predict'
    # Fast inference: return consecutive elements of the encoding sequence,
    # storing the index in state.
    return (inputs + np.expand_dims(params[:, state, :], 1), state + 1)
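# Illustrative sketch (not part of the original code): in 'predict' mode the
# layer is fed one token per call, so it adds one row of the encoding table
# at a time and advances an integer index kept in `state`. The loop below
# mimics that with plain NumPy; names and shapes are assumptions.
def _example_predict_mode_stepping():
  import numpy as onp
  params = onp.random.randn(1, 16, 4)          # (1, max_len, d_feature)
  tokens = onp.random.randn(3, 7, 4)           # full sequence, fed stepwise
  state = 0
  stepped = []
  for t in range(tokens.shape[1]):
    x_t = tokens[:, t:t + 1, :]                # a single new position
    stepped.append(x_t + onp.expand_dims(params[:, state, :], 1))
    state += 1
  # Matches applying the encoding to the whole sequence at once.
  assert onp.allclose(onp.concatenate(stepped, axis=1),
                      tokens + params[:, :tokens.shape[1], :])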
def PureMultiHeadedAttention(params, x, feature_depth=None, num_heads=8,
                             dropout=0.0, mode='train', **kwargs):
  """Pure transformer-style multi-headed attention.

  Args:
    params: parameters (none)
    x: inputs (q, k, v, mask)
    feature_depth: int: depth of embedding
    num_heads: int: number of attention heads
    dropout: float: dropout rate
    mode: str: 'train' or 'eval'
    **kwargs: other arguments including the rng

  Returns:
    Pure Multi-headed attention layer. (No Dense transforms on input.)
  """
  del params
  rng = kwargs.get('rng', None)
  q, k, v, mask = x
  assert feature_depth % num_heads == 0
  head_depth = feature_depth // num_heads
  nbatch = np.shape(q)[0]

  # nbatch, seqlen, feature_depth --> nbatch, num_heads, seqlen, head_depth
  def SplitHeads(x):
    return np.transpose(
        np.reshape(x, (nbatch, -1, num_heads, head_depth)), (0, 2, 1, 3))

  # nbatch, num_heads, seqlen, head_depth --> nbatch, seqlen, feature_depth
  def JoinHeads(x):  # pylint: disable=invalid-name
    return np.reshape(
        np.transpose(x, (0, 2, 1, 3)), (nbatch, -1, num_heads * head_depth))

  # Split heads, dot-product attention, rejoin heads.
  return JoinHeads(
      DotProductAttention(
          SplitHeads(q), SplitHeads(k), SplitHeads(v), mask,
          dropout=dropout, mode=mode, rng=rng))
def PositionalEncoding(x, params, **unused_kwargs):
  """Implements bare positional encoding."""
  symbol_size = np.shape(x)[1]
  return x + params[:, :symbol_size, :]
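# Illustrative sketch (not part of the original code): the `params` table
# added above is commonly a fixed sinusoidal table in the style of Vaswani et
# al. (2017). The actual initializer is not shown in this file, so the
# construction below is an assumption for illustration only.
def _example_sinusoidal_table(max_len=16, d_feature=8):
  import numpy as onp
  positions = onp.arange(max_len)[:, onp.newaxis]           # (max_len, 1)
  div_term = onp.exp(onp.arange(0, d_feature, 2) *
                     -(onp.log(10000.0) / d_feature))       # (d_feature/2,)
  pe = onp.zeros((max_len, d_feature))
  pe[:, 0::2] = onp.sin(positions * div_term)               # even channels
  pe[:, 1::2] = onp.cos(positions * div_term)               # odd channels
  return pe[onp.newaxis, :, :]                              # (1, max_len, d_feature)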
def call_and_grad(self, inputs, ct, rng=None, **kwargs):
  del kwargs
  query, key, value = inputs
  depth = np.shape(query)[-1]
  do_backprop = ct is not None
  # jax uses the term cotangent (ct) to refer to gradient signals, and
  # vector-Jacobian product (vjp) for back-propagation through a layer.

  def make_mask(N, M, k):  # pylint: disable=invalid-name
    """Constructs a slice of the causal attention mask.

    Args:
      N: number of query positions
      M: number of key positions
      k: position of the initial query element

    Returns:
      N x M mask, where 1.0 indicates that attention is not allowed.
    """
    x = np.arange(N, dtype=np.int32)
    y = np.arange(M, dtype=np.int32)
    mask = jax.lax.lt(
        (jax.lax.broadcast_in_dim(
            x, shape=(N, M), broadcast_dimensions=(0,)) + k),
        jax.lax.broadcast(y, [N]))
    mask = jax.lax.convert_element_type(mask, np.float32)
    return mask

  def forward_slice(query_slice, q_loop_idx, key, value):  # pylint: disable=invalid-name
    """Forward pass for a subset of the query vectors."""
    dots = np.matmul(
        query_slice, np.swapaxes(key, -1, -2)) / np.sqrt(depth)
    # Causal masking
    mask = make_mask(dots.shape[-2], dots.shape[-1], q_loop_idx)
    dots = dots - 1e9 * mask
    # Softmax.
    dots = np.exp(dots - backend.logsumexp(dots, axis=-1, keepdims=True))

    if self.dropout is not None and self.dropout > 0.0:
      # Dropout is broadcast across the batch+head dimension
      dropout_shape = (1, dots.shape[-2], dots.shape[-1])
      slice_rng = jax.random.fold_in(rng, q_loop_idx)
      keep_prob = jax.lax.tie_in(dots, 1.0 - self.dropout)
      keep = backend.random.bernoulli(slice_rng, keep_prob, dropout_shape)
      multiplier = keep.astype(dots.dtype) / jax.lax.tie_in(keep, keep_prob)
      dots = dots * multiplier

    out_slice = np.matmul(dots, value)
    return out_slice

  def forward_and_vjp_slice(query_slice, q_loop_idx, key, value, ct_slice):  # pylint: disable=invalid-name
    # Capture q_loop_idx to avoid calculating gradients wrt. it.
    def forward_slice_with_q_loop_idx(query_slice, key, value):  # pylint: disable=invalid-name
      return forward_slice(query_slice, q_loop_idx, key, value)

    output_slice, vjpfun = jax.vjp(
        forward_slice_with_q_loop_idx, query_slice, key, value)
    return output_slice, vjpfun(ct_slice)

  q_loop_idx = np.zeros((), dtype=np.int32)
  q_loop_max = query.shape[-2]
  q_loop_stride = self._loop_stride
  assert q_loop_max % q_loop_stride == 0, (
      'Stride must evenly divide the number of query elements.')

  out_accum = np.zeros_like(query)
  if do_backprop:
    query_ct_accum = np.zeros_like(query)
    key_ct_accum = np.zeros_like(key)
    value_ct_accum = np.zeros_like(value)
    init_vals = (q_loop_idx, out_accum,
                 query_ct_accum, key_ct_accum, value_ct_accum)
  else:
    init_vals = (q_loop_idx, out_accum)

  def cond_fun(vals):  # pylint: disable=invalid-name
    q_loop_idx = vals[0]
    return jax.lax.lt(q_loop_idx, q_loop_max)

  def body_fun(vals):  # pylint: disable=invalid-name
    """Compute a slice of the attention mechanism."""
    if do_backprop:
      (q_loop_idx, out_accum,
       query_ct_accum, key_ct_accum, value_ct_accum) = vals
    else:
      q_loop_idx, out_accum = vals

    query_slice = jax.lax.dynamic_slice_in_dim(
        query, q_loop_idx, q_loop_stride, axis=-2)

    if do_backprop:
      ct_slice = jax.lax.dynamic_slice_in_dim(
          ct, q_loop_idx, q_loop_stride, axis=-2)
      out_slice, partial_ct = forward_and_vjp_slice(
          query_slice, q_loop_idx, key, value, ct_slice)
      query_ct_accum = jax.lax.dynamic_update_slice_in_dim(
          query_ct_accum, partial_ct[0], q_loop_idx, axis=-2)
      key_ct_accum = key_ct_accum + partial_ct[1]
      value_ct_accum = value_ct_accum + partial_ct[2]
    else:
      out_slice = forward_slice(query_slice, q_loop_idx, key, value)

    out_accum = jax.lax.dynamic_update_slice_in_dim(
        out_accum, out_slice, q_loop_idx, axis=-2)
    q_loop_idx = q_loop_idx + q_loop_stride

    if do_backprop:
      return (q_loop_idx, out_accum,
              query_ct_accum, key_ct_accum, value_ct_accum)
    else:
      return (q_loop_idx, out_accum)

  final_vals = jax.lax.while_loop(cond_fun, body_fun, init_vals)

  if not do_backprop:
    return final_vals[1], None
  else:
    return final_vals[1], final_vals[2:]
def forward_and_vjp(self, inputs, ct, params=(), **kwargs):
  # This is the core of the memory-efficient attention implementation, where
  # we use the jax.lax.while_loop primitive to compute attention for a small
  # set of query positions at a time. Note how in the backwards pass, we
  # compute both the forward direction (to recover the previous layer's
  # activations) and the backward direction simultaneously. This allows us to
  # only use a single loop, where the inner portion of the loop does a slice
  # of the forward+backward joint computation. Unfortunately we have had to
  # introduce a large number of wrapper classes (including
  # ReversibleAttentionHalfResidual and ApplyAttentionWrapper) for the sole
  # purpose of connecting this implementation of forward_and_vjp with the
  # core backprop implementation.
  query, key, value = inputs
  depth = np.shape(query)[-1]
  do_backprop = ct is not None

  def make_mask(N, M, k):
    x = np.arange(N, dtype=np.int32)
    y = np.arange(M, dtype=np.int32)
    mask = jax.lax.lt(
        (jax.lax.broadcast_in_dim(
            x, shape=(N, M), broadcast_dimensions=(0,)) + k),
        jax.lax.broadcast(y, [N]))
    mask = jax.lax.convert_element_type(mask, np.float32)
    return mask

  def forward_slice(query_slice, q_loop_idx, key, value):
    """Forward pass for a subset of the query vectors."""
    dots = np.matmul(
        query_slice, np.swapaxes(key, -1, -2)) / np.sqrt(depth)
    # Causal masking
    mask = make_mask(dots.shape[-2], dots.shape[-1], q_loop_idx)
    dots = dots - 1e9 * mask
    # Softmax.
    dots = np.exp(dots - dots.max(axis=-1, keepdims=True))
    dots = dots / dots.sum(axis=-1, keepdims=True)
    out_slice = np.matmul(dots, value)
    return out_slice

  def forward_and_vjp_slice(query_slice, q_loop_idx, key, value, ct_slice):
    output_slice, vjpfun = jax.vjp(
        forward_slice, query_slice, q_loop_idx, key, value)
    return output_slice, vjpfun(ct_slice)

  q_loop_idx = np.zeros((), dtype=np.int32)
  q_loop_max = query.shape[2]
  q_loop_stride = self._loop_stride
  assert q_loop_max % q_loop_stride == 0, (
      'Stride must evenly divide the number of query elements.')

  out_accum = np.zeros_like(query)
  if do_backprop:
    query_ct_accum = np.zeros_like(query)
    key_ct_accum = np.zeros_like(key)
    value_ct_accum = np.zeros_like(value)
    init_vals = (q_loop_idx, out_accum,
                 query_ct_accum, key_ct_accum, value_ct_accum)
  else:
    init_vals = (q_loop_idx, out_accum)

  def cond_fun(vals):
    q_loop_idx = vals[0]
    return jax.lax.lt(q_loop_idx, q_loop_max)

  def body_fun(vals):
    """Compute a slice of the attention mechanism."""
    if do_backprop:
      (q_loop_idx, out_accum,
       query_ct_accum, key_ct_accum, value_ct_accum) = vals
    else:
      q_loop_idx, out_accum = vals

    query_slice = jax.lax.dynamic_slice_in_dim(
        query, q_loop_idx, q_loop_stride, axis=2)

    if do_backprop:
      ct_slice = jax.lax.dynamic_slice_in_dim(
          ct, q_loop_idx, q_loop_stride, axis=2)
      out_slice, partial_ct = forward_and_vjp_slice(
          query_slice, q_loop_idx, key, value, ct_slice)
      query_ct_accum = jax.lax.dynamic_update_slice_in_dim(
          query_ct_accum, partial_ct[0], q_loop_idx, axis=2)
      # ignore partial_ct[1], which is wrt the loop idx
      key_ct_accum = key_ct_accum + partial_ct[2]
      value_ct_accum = value_ct_accum + partial_ct[3]
    else:
      out_slice = forward_slice(query_slice, q_loop_idx, key, value)

    out_accum = jax.lax.dynamic_update_slice_in_dim(
        out_accum, out_slice, q_loop_idx, axis=2)
    q_loop_idx = q_loop_idx + q_loop_stride

    if do_backprop:
      return (q_loop_idx, out_accum,
              query_ct_accum, key_ct_accum, value_ct_accum)
    else:
      return (q_loop_idx, out_accum)

  final_vals = jax.lax.while_loop(cond_fun, body_fun, init_vals)

  if not do_backprop:
    return final_vals[1], None
  else:
    return final_vals[1], final_vals[2:]
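# Illustrative sketch (not part of the original code): the loops above
# process `q_loop_stride` query positions per iteration, so the full
# (seqlen x seqlen) attention matrix is never materialized at once. The
# plain-NumPy loop below shows the same chunking idea (without while_loop,
# dropout, or the custom vjp); all names and shapes are assumptions.
def _example_chunked_causal_attention(stride=2):
  import numpy as onp
  length, depth = 8, 4
  q = onp.random.randn(length, depth)
  k = onp.random.randn(length, depth)
  v = onp.random.randn(length, depth)

  def slice_attention(q_slice, start):
    dots = onp.matmul(q_slice, k.T) / onp.sqrt(depth)
    # Causal mask for this slice: query at global position i may attend to
    # keys 0..i, where i = start + row index within the slice.
    rows = onp.arange(q_slice.shape[0])[:, None] + start
    cols = onp.arange(length)[None, :]
    dots = onp.where(cols <= rows, dots, -1e9)
    dots = onp.exp(dots - dots.max(axis=-1, keepdims=True))
    dots = dots / dots.sum(axis=-1, keepdims=True)
    return onp.matmul(dots, v)

  # Chunked computation over query slices...
  out = onp.concatenate(
      [slice_attention(q[i:i + stride], i) for i in range(0, length, stride)],
      axis=0)
  # ...matches attending with all queries at once.
  full = slice_attention(q, 0)
  assert onp.allclose(out, full)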
def JoinHeads(x, params, **kwargs):
  del params, kwargs
  n_batch = np.shape(x)[0]
  seqlen = np.shape(x)[2]
  # n_batch, n_heads, seqlen, d_head --> n_batch, seqlen, d_model
  return np.reshape(np.transpose(x, (0, 2, 1, 3)), (n_batch, seqlen, -1))
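# Illustrative sketch (not part of the original code): SplitHeads followed by
# JoinHeads is a pure shape round trip, so composing the two reshapes
# recovers the original (n_batch, seqlen, d_model) tensor. Plain NumPy,
# assumed shapes.
def _example_split_join_round_trip(n_heads=4):
  import numpy as onp
  x = onp.random.randn(2, 6, 16)                 # (n_batch, seqlen, d_model)
  n_batch, seqlen, d_model = x.shape
  d_head = d_model // n_heads
  split = onp.transpose(
      onp.reshape(x, (n_batch, -1, n_heads, d_head)), (0, 2, 1, 3))
  joined = onp.reshape(
      onp.transpose(split, (0, 2, 1, 3)), (n_batch, seqlen, -1))
  assert onp.allclose(joined, x)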
def apply_fun(params, inputs, **kwargs):
  del kwargs
  pe = params
  symbol_size = np.shape(inputs)[1]
  return inputs + pe[:, :symbol_size]