def CausalAttention(d_feature, n_heads=1, dropout=0.0,
                    max_inference_length=2048, mode='train'):
  """Returns a layer that maps activations to activations, with causal masking.

  Like ``Attention``, this layer type represents one pass of multi-head
  self-attention, but with causal masking rather than padding-based masking.

  Args:
    d_feature: Depth/dimensionality of feature embedding.
    n_heads: Number of attention heads.
    dropout: Probabilistic rate for internal dropout applied to attention
        activations (based on query-key pairs) before dotting them with values.
    max_inference_length: Maximum length for inference.
    mode: One of ``'train'``, ``'eval'``, or ``'predict'``.
  """
  if d_feature % n_heads != 0:
    raise ValueError(
        f'Dimensionality of feature embedding ({d_feature}) is not a multiple '
        f'of the requested number of attention heads ({n_heads}).')

  return ConfigurableAttention(
      core.Dense(d_feature), core.Dense(d_feature), core.Dense(d_feature),
      core.Dense(d_feature), n_heads=n_heads,
      qkv_attention_layer=DotProductCausalAttention(
          dropout=dropout, max_inference_length=max_inference_length,
          mode=mode))
def QKVLayer():
  """Function returning the Q, K and V layer."""
  if use_dconv:
    return cb.Serial(core.Dense(d_feature), convolution.CausalDepthwiseConv())
  else:
    return core.Dense(d_feature)
def AttentionQKV(d_feature, n_heads=1, dropout=0.0, mode='train'):
  """Transformer-style multi-headed attention.

  Accepts inputs of the form q, k, v, mask.

  Args:
    d_feature: int: dimensionality of feature embedding
    n_heads: int: number of attention heads
    dropout: float: dropout rate
    mode: str: 'train' or 'eval'

  Returns:
    Multi-headed self-attention result and the mask.
  """
  return cb.Serial(
      cb.Parallel(
          core.Dense(d_feature),
          core.Dense(d_feature),
          core.Dense(d_feature),
      ),
      PureAttention(  # pylint: disable=no-value-for-parameter
          n_heads=n_heads, dropout=dropout, mode=mode),
      core.Dense(d_feature),
  )
def CausalAttention(d_feature, n_heads=1, dropout=0.0,
                    max_inference_length=2048, mode='train'):
  """Returns a layer that maps activations to activations, with causal masking.

  Like :py:class:`Attention`, this layer type represents one pass of multi-head
  self-attention, but with causal masking rather than padding-based masking.

  Args:
    d_feature: Last/innermost dimension of activations in the input to and
        output from this layer.
    n_heads: Number of attention heads. Attention heads effectively split
        activation vectors into ``n_heads`` subvectors, of size
        ``d_feature / n_heads``.
    dropout: Probabilistic rate for attention dropout, which overrides
        (sets to zero) some attention strengths derived from query-key
        matching. As a result, on a given forward pass, some value vectors
        don't contribute to the output, analogous to how regular dropout can
        cause some node activations to be ignored. Applies only if layer is
        created in ``'train'`` mode.
    max_inference_length: Maximum sequence length allowed in non-training
        modes.
    mode: One of ``'train'``, ``'eval'``, or ``'predict'``.
  """
  if d_feature % n_heads != 0:
    raise ValueError(
        f'Dimensionality of feature embedding ({d_feature}) is not a multiple '
        f'of the requested number of attention heads ({n_heads}).')

  return ConfigurableAttention(
      core.Dense(d_feature), core.Dense(d_feature), core.Dense(d_feature),
      core.Dense(d_feature), n_heads=n_heads,
      qkv_attention_layer=DotProductCausalAttention(
          dropout=dropout, max_inference_length=max_inference_length,
          mode=mode))
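# A minimal NumPy sketch (not the Trax implementation) of the computation that
# a causal dot-product attention core performs for a single head: scaled
# dot-product attention where position i may only attend to positions <= i.
# The helper name `causal_dot_product_attention` is illustrative only.
import numpy as np

def causal_dot_product_attention(q, k, v):
  """q, k, v: arrays of shape (seq_len, d_head)."""
  d_head = q.shape[-1]
  scores = q @ k.T / np.sqrt(d_head)               # (seq_len, seq_len)
  causal_mask = np.tril(np.ones_like(scores, dtype=bool))
  scores = np.where(causal_mask, scores, -1e9)     # block attention to the future
  weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
  weights /= weights.sum(axis=-1, keepdims=True)   # row-wise softmax
  return weights @ v                               # (seq_len, d_head)

# Example: with 4 positions, row 0 of the attention weights puts all mass on
# position 0, since every later position is masked out.
q = k = v = np.arange(12, dtype=np.float32).reshape(4, 3)
print(causal_dot_product_attention(q, k, v).shape)  # (4, 3)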
def CausalAttention(d_feature, n_heads=1, dropout=0.0,
                    max_inference_length=2048, mode='train'):
  """Returns a layer that maps activations to activations, with causal masking.

  Like `Attention`, this layer type represents one pass of multi-head
  self-attention, but with causal masking rather than padding-based masking.

  Args:
    d_feature: Depth/dimensionality of feature embedding.
    n_heads: Number of attention heads.
    dropout: Probabilistic rate for internal dropout applied to attention
        activations (based on query-key pairs) before dotting them with values.
    max_inference_length: Maximum length for inference.
    mode: One of `'train'`, `'eval'`, or `'predict'`.
  """
  if d_feature % n_heads != 0:
    raise ValueError(
        f'Dimensionality of feature embedding ({d_feature}) is not a multiple '
        f'of the requested number of attention heads ({n_heads}).')
  d_head = d_feature // n_heads

  def _split_into_heads():
    """Returns a layer that reshapes tensors for multi-headed computation."""
    def f(x):
      batch_size = x.shape[0]
      seq_len = x.shape[1]

      # (b_size, seq_len, d_feature) --> (b_size*n_heads, seq_len, d_head)
      x = x.reshape((batch_size, seq_len, n_heads, d_head))
      x = x.transpose((0, 2, 1, 3))
      x = x.reshape((-1, seq_len, d_head))
      return x
    return Fn('SplitIntoHeads', f)

  def _merge_heads():
    """Returns a layer that undoes splitting, after multi-head computation."""
    def f(x):
      seq_len = x.shape[1]

      # (b_size*n_heads, seq_len, d_head) --> (b_size, seq_len, d_feature)
      x = x.reshape((-1, n_heads, seq_len, d_head))
      x = x.transpose((0, 2, 1, 3))
      x = x.reshape((-1, seq_len, n_heads * d_head))
      return x
    return Fn('MergeHeads', f)

  return cb.Serial(
      cb.Branch(
          [core.Dense(d_feature), _split_into_heads()],
          [core.Dense(d_feature), _split_into_heads()],
          [core.Dense(d_feature), _split_into_heads()],
      ),
      DotProductCausalAttention(
          dropout=dropout, max_inference_length=max_inference_length,
          mode=mode),
      _merge_heads(),
      core.Dense(d_feature),
  )
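# A small NumPy sketch of the reshape/transpose pattern used by
# _split_into_heads / _merge_heads above: (batch, seq, d_feature) is folded to
# (batch * n_heads, seq, d_head) and back. Purely an illustrative shape check.
import numpy as np

batch, seq_len, n_heads, d_head = 2, 5, 4, 8
d_feature = n_heads * d_head
x = np.random.rand(batch, seq_len, d_feature)

split = x.reshape(batch, seq_len, n_heads, d_head)
split = split.transpose(0, 2, 1, 3).reshape(-1, seq_len, d_head)
assert split.shape == (batch * n_heads, seq_len, d_head)

merged = split.reshape(-1, n_heads, seq_len, d_head)
merged = merged.transpose(0, 2, 1, 3).reshape(-1, seq_len, n_heads * d_head)
assert merged.shape == (batch, seq_len, d_feature)
assert np.allclose(merged, x)  # the split/merge round trip is lossless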
def RelativeAttentionLayer(d_feature, context_bias_layer, location_bias_layer,
                           total_kv_pooling, separate_cls, n_heads=1,
                           dropout=0.0, mode='train'):
  """Returns a layer that maps (q, k, v, masks) to (activations, masks).

  When the number of keys is smaller than the number of queries, the layer
  works in O(q^2*d); otherwise it is O(q*k*d). That is because we need to
  shift relative distances by the current pooling; when we upsample, the
  current pooling is a fraction < 1.

  Visual explanation:
  [01][23][45][67] -> [0][1][2][3][4][5][6][7]

  For token [0] we calculate relative distances as follows:
  * 0 2 4 6
  However, for token [1] the relative distances must be shifted by 1,
  specifically:
  * -1 1 3 5
  So we need to calculate not only the distances that correspond to the
  spacing between the keys but also the ones in between, because there is more
  than one query token (at different positions, hence with different relative
  distances) per single key token. A small sketch of this shift appears below.

  Args:
    d_feature: Depth/dimensionality of feature embedding.
    context_bias_layer: Global context bias from Transformer XL's attention.
        There should be one such layer shared for all relative attention
        layers.
    location_bias_layer: Global location bias from Transformer XL's attention.
        There should be one such layer shared for all relative attention
        layers.
    total_kv_pooling: Accumulated pool size of keys/values used at this layer.
    separate_cls: True/False if we separate_cls in calculations.
    n_heads: Number of attention heads.
    dropout: Probabilistic rate for internal dropout applied to attention
        activations (based on query-key pairs) before dotting them with values.
    mode: One of `'train'`, `'eval'`, or `'predict'`.
  """
  return cb.Serial(
      cb.Branch(
          PositionalEmbeddings(d_feature, separate_cls, total_kv_pooling),
          cb.Select([0]),
          cb.Select([1])),
      cb.Parallel(
          core.Dense(d_feature),
          core.Dense(d_feature),
          core.Dense(d_feature),
          core.Dense(d_feature),
      ),
      context_bias_layer,
      location_bias_layer,
      RelativeAttention(  # pylint: disable=no-value-for-parameter
          separate_cls=separate_cls, n_heads=n_heads, dropout=dropout,
          mode=mode),
      core.Dense(d_feature),
  )
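# A tiny NumPy sketch of the distance-shifting idea from the docstring above:
# with total_kv_pooling = 2, key j sits at position 2*j while queries sit at
# every position, so each query needs its own shifted row of key distances.
# Illustrative only; the real layer builds these via PositionalEmbeddings.
import numpy as np

total_kv_pooling = 2
n_queries, n_keys = 8, 4
query_positions = np.arange(n_queries)                 # [0, 1, ..., 7]
key_positions = np.arange(n_keys) * total_kv_pooling   # [0, 2, 4, 6]
rel_distances = key_positions[None, :] - query_positions[:, None]
print(rel_distances[0])  # [ 0  2  4  6]  -- row for query token [0]
print(rel_distances[1])  # [-1  1  3  5]  -- row for query token [1], shifted by 1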
def FeedForwardBlock(d_model, d_ff, dropout, dropout_shared_axes, mode,
                     activation):
  # We copy the ff block function because we cannot import it from models
  return [
      core.Dense(d_ff),
      activation(),
      core.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode),
      core.Dense(d_model),
  ]
def test_dense_param_sharing(self):
  model1 = combinators.Serial(core.Dense(32), core.Dense(32))
  layer = core.Dense(32)
  model2 = combinators.Serial(layer, layer)

  rng1, rng2 = backend.random.split(backend.random.get_prng(0), 2)
  params1, _ = model1.initialize_once((1, 32), onp.float32, rng1)
  params2, _ = model2.initialize_once((1, 32), onp.float32, rng2)
  # The first parameters have 2 kernels of size (32, 32).
  self.assertEqual((32, 32), params1[0][0].shape)
  self.assertEqual((32, 32), params1[1][0].shape)
  # The second parameters have 1 kernel of size (32, 32) and an empty tuple.
  self.assertEqual((32, 32), params2[0][0].shape)
  self.assertEqual((), params2[1])
def test_dense_weight_sharing(self):
  model1 = combinators.Serial(core.Dense(32), core.Dense(32))
  layer = core.Dense(32)
  model2 = combinators.Serial(layer, layer)

  input_signature = ShapeDtype((1, 32))
  weights1, _ = model1.init(input_signature)
  weights2, _ = model2.init(input_signature)
  # The first weights have 2 kernels of size (32, 32).
  self.assertEqual((32, 32), weights1[0][0].shape)
  self.assertEqual((32, 32), weights1[1][0].shape)
  # The second weights have 1 kernel of size (32, 32) and an empty tuple.
  self.assertEqual((32, 32), weights2[0][0].shape)
  self.assertEqual((), weights2[1])
def test_dense_param_sharing(self):
  model1 = combinators.Serial(core.Dense(32), core.Dense(32))
  layer = core.Dense(32)
  model2 = combinators.Serial(layer, layer)

  input_signature = ShapeDtype((1, 32))
  params1, _ = model1.initialize_once(input_signature)
  params2, _ = model2.initialize_once(input_signature)
  # The first parameters have 2 kernels of size (32, 32).
  self.assertEqual((32, 32), params1[0][0].shape)
  self.assertEqual((32, 32), params1[1][0].shape)
  # The second parameters have 1 kernel of size (32, 32) and an empty tuple.
  self.assertEqual((32, 32), params2[0][0].shape)
  self.assertEqual((), params2[1])
def SRU(n_units, activation=None):
  """SRU (Simple Recurrent Unit) layer as in https://arxiv.org/abs/1709.02755.

  As defined in the paper:
  (1) y_t = W x_t (+ B optionally, which we do)
  (2) f_t = sigmoid(Wf x_t + bf)
  (3) r_t = sigmoid(Wr x_t + br)
  (4) c_t = f_t * c_{t-1} + (1 - f_t) * y_t
  (5) h_t = r_t * activation(c_t) + (1 - r_t) * x_t

  We assume the input is of shape [batch, length, depth] and recurrence
  happens on the length dimension. This returns a single layer. It's best
  to use at least 2, they say in the paper, except inside a Transformer.

  Args:
    n_units: output depth of the SRU layer.
    activation: Optional activation function.

  Returns:
    The SRU layer.
  """
  sigmoid_activation = activation_fns.Sigmoid()  # pylint: disable=no-value-for-parameter
  return cb.Serial(                                         # x
      cb.Branch(core.Dense(3 * n_units), []),               # r_f_y, x
      cb.Split(n_items=3),                                  # r, f, y, x
      cb.Parallel(sigmoid_activation, sigmoid_activation),  # r, f, y, x
      base.Fn(lambda r, f, y: (y * (1.0 - f), f, r)),       # y * (1 - f), f, r, x
      cb.Parallel([], [], cb.Branch(MakeZeroState(), [])),
      cb.Scan(InnerSRUCell(), axis=1),
      cb.Select([0], n_in=2),                               # act(c), r, x
      activation or [],
      base.Fn(lambda c, r, x: c * r + x * (1 - r)))
def AttentionQKV(d_feature, n_heads=1, dropout=0.0, mode='train'):
  """Returns a layer that maps (q, k, v, mask) to (activations, mask).

  Args:
    d_feature: Depth/dimensionality of feature embedding.
    n_heads: Number of attention heads.
    dropout: Probabilistic rate for internal dropout applied to attention
        activations (based on query-key pairs) before dotting them with values.
    mode: Either 'train' or 'eval'.
  """
  return cb.Serial(
      cb.Parallel(
          core.Dense(d_feature),
          core.Dense(d_feature),
          core.Dense(d_feature),
      ),
      PureAttention(n_heads=n_heads, dropout=dropout, mode=mode),
      core.Dense(d_feature),
  )
def CausalAttention(d_feature, n_heads=1, dropout=0.0, mode='train'):
  """Transformer-style multi-headed causal attention.

  Args:
    d_feature: int: dimensionality of feature embedding
    n_heads: int: number of attention heads
    dropout: float: attention dropout
    mode: str: 'train' or 'eval'

  Returns:
    Multi-headed self-attention result.
  """
  assert d_feature % n_heads == 0
  d_head = d_feature // n_heads

  def compute_attention_heads(x):
    batch_size = x.shape[0]
    seqlen = x.shape[1]
    # n_batch, seqlen, n_heads*d_head -> n_batch, seqlen, n_heads, d_head
    x = jnp.reshape(x, (batch_size, seqlen, n_heads, d_head))
    # n_batch, seqlen, n_heads, d_head -> n_batch, n_heads, seqlen, d_head
    x = jnp.transpose(x, (0, 2, 1, 3))
    # n_batch, n_heads, seqlen, d_head -> n_batch*n_heads, seqlen, d_head
    return jnp.reshape(x, (-1, seqlen, d_head))

  ComputeAttentionHeads = Fn('ComputeAttentionHeads', compute_attention_heads)

  def compute_attention_output(x):
    seqlen = x.shape[1]
    x = jnp.reshape(x, (-1, n_heads, seqlen, d_head))
    x = jnp.transpose(x, (0, 2, 1, 3))  # -> n_batch, seqlen, n_heads, d_head
    return jnp.reshape(x, (-1, seqlen, n_heads * d_head))

  return cb.Serial(
      cb.Branch(
          [core.Dense(d_feature), ComputeAttentionHeads],
          [core.Dense(d_feature), ComputeAttentionHeads],
          [core.Dense(d_feature), ComputeAttentionHeads],
      ),
      DotProductCausalAttention(dropout=dropout, mode=mode),
      Fn('ComputeAttentionOutput', compute_attention_output),
      core.Dense(d_feature)
  )
def _SparsifiableDense(layer_sparsity):
  if layer_sparsity is None:
    return core.Dense(d_feature)
  elif layer_sparsity == 'noop':
    return cb.Serial()  # No-op layer.
  else:
    d_module = d_feature // layer_sparsity
    return cb.Serial(
        sparsity.FactoredDense(layer_sparsity, d_feature, d_feature),
        sparsity.LocallyConvDense(layer_sparsity, d_module, mode=mode,
                                  kernel_size=3, length_kernel_size=3)
    )
def AttentionQKV(d_feature, n_heads=1, dropout=0.0, mode='train'):
  """Returns a layer that maps (q, k, v, mask) to (activations, mask).

  See `Attention` above for further context/details.

  Args:
    d_feature: Depth/dimensionality of feature embedding.
    n_heads: Number of attention heads.
    dropout: Probabilistic rate for internal dropout applied to attention
        activations (based on query-key pairs) before dotting them with values.
    mode: One of `'train'`, `'eval'`, or `'predict'`.
  """
  return cb.Serial(
      cb.Parallel(
          core.Dense(d_feature),
          core.Dense(d_feature),
          core.Dense(d_feature),
      ),
      PureAttention(  # pylint: disable=no-value-for-parameter
          n_heads=n_heads, dropout=dropout, mode=mode),
      core.Dense(d_feature),
  )
def LinearUpsampling(shorten_factor, d_model, *args, dropout=0.0, mode='train',
                     **kwargs):
  del args, kwargs

  return cb.Serial(
      core.Dense(shorten_factor * d_model),
      core.Dropout(rate=dropout, mode=mode),
      core.Fn(
          'ProlongBack',
          lambda x: jnp.reshape(  # pylint: disable=g-long-lambda
              # Prolong back.
              x, (x.shape[0], x.shape[1] * shorten_factor, -1)),
          n_out=1)
  )
def LinearPooling(shorten_factor, d_model, *args, dropout=0.0, mode='train',
                  **kwargs):
  del args, kwargs

  return cb.Serial(
      core.Fn(
          'Shorten',
          lambda x: jnp.reshape(  # pylint: disable=g-long-lambda
              # Shorten -- move to depth.
              x, (x.shape[0], x.shape[1] // shorten_factor, -1)),
          n_out=1),
      core.Dense(d_model),
      core.Dropout(rate=dropout, mode=mode)
  )
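# A minimal NumPy sketch of the reshapes used by LinearPooling ('Shorten') and
# LinearUpsampling ('ProlongBack') above: pooling folds `shorten_factor`
# consecutive positions into the depth axis, and upsampling unfolds them
# again. The Dense and Dropout layers in between are omitted here.
import numpy as np

shorten_factor, batch, seq_len, d_model = 2, 1, 8, 4
x = np.random.rand(batch, seq_len, d_model)

shortened = x.reshape(batch, seq_len // shorten_factor, -1)
assert shortened.shape == (batch, seq_len // shorten_factor,
                           shorten_factor * d_model)

prolonged = shortened.reshape(batch, shortened.shape[1] * shorten_factor, -1)
assert prolonged.shape == x.shape
assert np.allclose(prolonged, x)  # without the Dense layers this is lossless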
def test_set_rng_serial_recurse_two_levels(self):
  dense_00 = core.Dense(2)
  dense_01 = core.Dense(2)
  dense_10 = core.Dense(2)
  dense_11 = core.Dense(2)
  layer = cb.Serial(
      cb.Serial(dense_00, dense_01),
      cb.Serial(dense_10, dense_11),
  )
  input_signature = ShapeDtype((1, 2))

  _, _ = layer.init(input_signature)
  weights = layer.weights
  dense_00_w, dense_00_b = weights[0][0]
  dense_01_w, dense_01_b = weights[0][1]
  dense_10_w, dense_10_b = weights[1][0]
  dense_11_w, dense_11_b = weights[1][1]

  # Setting rng's recursively during init should yield differing weights.
  self.assertFalse(np.array_equal(dense_00_w, dense_01_w))
  self.assertFalse(np.array_equal(dense_00_b, dense_01_b))
  self.assertFalse(np.array_equal(dense_10_w, dense_11_w))
  self.assertFalse(np.array_equal(dense_10_b, dense_11_b))
def GRUCell(n_units):
  """Builds a traditional GRU cell with dense internal transformations.

  Gated Recurrent Unit paper: https://arxiv.org/abs/1412.3555

  Args:
    n_units: Number of hidden units.

  Returns:
    A Stax model representing a traditional GRU RNN cell.
  """
  return GeneralGRUCell(candidate_transform=lambda: core.Dense(n_units),
                        memory_transform_fn=None,
                        gate_nonlinearity=core.Sigmoid,
                        candidate_nonlinearity=core.Tanh)
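# A minimal NumPy sketch of one step of the standard GRU equations (Cho et
# al.) that a dense GRU cell like the one above computes; the weight names and
# layout here are illustrative and not Trax's parameter structure.
import numpy as np

def sigmoid(x):
  return 1.0 / (1.0 + np.exp(-x))

def gru_step(x, h_prev, Wu, Wr, Wc, bu, br, bc):
  """x: (d_in,), h_prev: (n_units,); W*: (d_in + n_units, n_units)."""
  xh = np.concatenate([x, h_prev])
  u = sigmoid(xh @ Wu + bu)                               # update gate
  r = sigmoid(xh @ Wr + br)                               # reset gate
  c = np.tanh(np.concatenate([x, r * h_prev]) @ Wc + bc)  # candidate state
  return (1.0 - u) * h_prev + u * c                       # new hidden state

# Quick shape check with random weights.
d_in, n_units = 3, 5
rng = np.random.default_rng(0)
h = gru_step(rng.standard_normal(d_in), np.zeros(n_units),
             *[rng.standard_normal((d_in + n_units, n_units)) for _ in range(3)],
             *[np.zeros(n_units) for _ in range(3)])
assert h.shape == (n_units,)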
def SRU(n_units, activation=None, mode='train'):
  r"""SRU (Simple Recurrent Unit) layer as in https://arxiv.org/abs/1709.02755.

  As defined in the paper:

  .. math::
    y_t &= W x_t + B \quad \hbox{(include $B$ optionally)} \\
    f_t &= \sigma(Wf x_t + bf) \\
    r_t &= \sigma(Wr x_t + br) \\
    c_t &= f_t \times c_{t-1} + (1 - f_t) \times y_t \\
    h_t &= r_t \times \hbox{activation}(c_t) + (1 - r_t) \times x_t

  We assume the input is of shape [batch, length, depth] and recurrence
  happens on the length dimension. This returns a single layer. It's best
  to use at least 2, they say in the paper, except inside a Transformer.

  Args:
    n_units: output depth of the SRU layer.
    activation: Optional activation function.
    mode: if 'predict' then we save the previous state for one-by-one inference

  Returns:
    The SRU layer.
  """
  sigmoid_activation = activation_fns.Sigmoid()
  return cb.Serial(                                         # x
      cb.Branch(core.Dense(3 * n_units), []),               # r_f_y, x
      cb.Split(n_items=3),                                  # r, f, y, x
      cb.Parallel(sigmoid_activation, sigmoid_activation),  # r, f, y, x
      base.Fn(
          '',
          lambda r, f, y: (y * (1.0 - f), f, r),            # y * (1 - f), f, r, x
          n_out=3),
      cb.Parallel([], [], cb.Branch(MakeZeroState(), [])),
      ScanSRUCell(mode=mode),
      cb.Select([0], n_in=2),                               # act(c), r, x
      activation if activation is not None else [],
      base.Fn('FinalSRUGate', lambda c, r, x: c * r + x * (1 - r) * (3**0.5)),
      # Set the name to SRU and don't print sublayers.
      name=f'SRU_{n_units}', sublayers_to_print=[])
def SRU(n_units, activation=None):
  r"""SRU (Simple Recurrent Unit) layer as in https://arxiv.org/abs/1709.02755.

  As defined in the paper:

  .. math::
    y_t &= W x_t + B \quad \hbox{(include $B$ optionally)} \\
    f_t &= \sigma(Wf x_t + bf) \\
    r_t &= \sigma(Wr x_t + br) \\
    c_t &= f_t \times c_{t-1} + (1 - f_t) \times y_t \\
    h_t &= r_t \times \hbox{activation}(c_t) + (1 - r_t) \times x_t

  We assume the input is of shape [batch, length, depth] and recurrence
  happens on the length dimension. This returns a single layer. It's best
  to use at least 2, they say in the paper, except inside a Transformer.

  Args:
    n_units: output depth of the SRU layer.
    activation: Optional activation function.

  Returns:
    The SRU layer.
  """
  sigmoid_activation = activation_fns.Sigmoid()
  return cb.Serial(                                         # x
      cb.Branch(core.Dense(3 * n_units), []),               # r_f_y, x
      cb.Split(n_items=3),                                  # r, f, y, x
      cb.Parallel(sigmoid_activation, sigmoid_activation),  # r, f, y, x
      base.Fn(
          '',
          lambda r, f, y: (y * (1.0 - f), f, r),            # y * (1 - f), f, r, x
          n_out=3),
      cb.Parallel([], [], cb.Branch(MakeZeroState(), [])),
      cb.Scan(InnerSRUCell(), axis=1),
      cb.Select([0], n_in=2),                               # act(c), r, x
      activation or [],
      base.Fn('FinalSRUGate', lambda c, r, x: c * r + x * (1 - r) * (3**0.5)))
def SRU(n_units, activation=None, rescale=False, highway_bias=0):
  """SRU (Simple Recurrent Unit) layer as in https://arxiv.org/abs/1709.02755.

  As defined in the paper:
  (1) y_t = W x_t (+ B optionally, which we do)
  (2) f_t = sigmoid(Wf x_t + bf)
  (3) r_t = sigmoid(Wr x_t + br)
  (4) c_t = f_t * c_{t-1} + (1 - f_t) * y_t
  (5) h_t = r_t * activation(c_t) + (1 - r_t) * x_t * alpha

  We assume the input is of shape [batch, length, depth] and recurrence
  happens on the length dimension. This returns a single layer. It's best
  to use at least 2, they say in the paper, except inside a Transformer.
  A plain NumPy sketch of this recurrence appears below.

  Args:
    n_units: output depth of the SRU layer.
    activation: Optional activation function.
    rescale: To offset the problem of the gradient vanishing in h_t as a
      result of light recurrence and highway computation for deeper layers, a
      scaling correction alpha = (1 + exp(highway_bias) * 2)**0.5 is applied
      (ref: https://arxiv.org/abs/1709.02755, page 4, section 3.2,
      Initialization).
    highway_bias: initial bias of highway gates

  Returns:
    The SRU layer.
  """
  # pylint: disable=no-value-for-parameter
  return cb.Serial(                                     # x
      cb.Branch(core.Dense(3 * n_units), []),           # r_f_y, x
      cb.Split(n_items=3),                              # r, f, y, x
      cb.Parallel(core.Sigmoid(), core.Sigmoid()),      # r, f, y, x
      base.Fn(lambda r, f, y: (y * (1.0 - f), f, r)),   # y * (1 - f), f, r, x
      cb.Parallel([], [], cb.Branch(MakeZeroState(), [])),
      cb.Scan(InnerSRUCell(), axis=1),
      cb.Select([0], n_in=2),                           # act(c), r, x
      activation or [],
      base.Fn(lambda c, r, x: c * r + x * (1 - r) *
              ((1 + np.exp(highway_bias) * 2)**0.5 if rescale else 1)))
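# A minimal NumPy sketch of the SRU recurrence, equations (1)-(5) above, for a
# single sequence, including the optional rescale factor
# alpha = (1 + exp(highway_bias) * 2)**0.5. Weight names are illustrative, and
# the highway term assumes depth == n_units.
import numpy as np

def sigmoid(x):
  return 1.0 / (1.0 + np.exp(-x))

def sru_forward(x, W, Wf, bf, Wr, br, highway_bias=0.0, rescale=False,
                activation=np.tanh):
  """x: (length, depth); W, Wf, Wr: (depth, n_units); returns (length, n_units)."""
  alpha = (1.0 + np.exp(highway_bias) * 2.0) ** 0.5 if rescale else 1.0
  y = x @ W                                   # (1) y_t = W x_t
  f = sigmoid(x @ Wf + bf)                    # (2) forget gate
  r = sigmoid(x @ Wr + br)                    # (3) reset/highway gate
  c = np.zeros(W.shape[1])
  h = []
  for t in range(x.shape[0]):
    c = f[t] * c + (1.0 - f[t]) * y[t]        # (4) light recurrence
    h.append(r[t] * activation(c) + (1.0 - r[t]) * x[t] * alpha)  # (5)
  return np.stack(h)

# Quick shape check: length 7, depth == n_units == 5.
x = np.random.rand(7, 5)
params = [np.random.rand(5, 5) for _ in range(3)]
h = sru_forward(x, params[0], params[1], np.zeros(5), params[2], np.zeros(5),
                highway_bias=0.0, rescale=True)
assert h.shape == (7, 5)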
def test_weights_serial(self):
  model = cb.Serial(core.Dense(4), core.Dense(5), core.Dense(7))
  self.assertIsInstance(model.weights, tuple)
  self.assertLen(model.weights, 3)
def _CacheableDense():
  if cache_KV_in_predict and mode == 'predict':
    return cb.Cache(core.Dense(d_feature))
  else:
    return core.Dense(d_feature)
def AttentionQKV(d_feature, n_heads=1, dropout=0.0, mode='train',
                 cache_KV_in_predict=False, q_sparsity=None,
                 result_sparsity=None):
  """Returns a layer that maps (q, k, v, mask) to (activations, mask).

  See ``Attention`` above for further context/details.

  Args:
    d_feature: Depth/dimensionality of feature embedding.
    n_heads: Number of attention heads.
    dropout: Probabilistic rate for internal dropout applied to attention
        activations (based on query-key pairs) before dotting them with values.
    mode: One of ``'train'``, ``'eval'``, or ``'predict'``.
    cache_KV_in_predict: Whether to cache K/V tensors in predict mode.
    q_sparsity: Sparsity with which to process queries. If None, Dense is
        used. If 'noop' then no processing is used.
    result_sparsity: Sparsity with which to process result of the attention.
        If None, Dense is used. If 'noop' then no processing is used.
  """
  k_processor = core.Dense(d_feature)
  v_processor = core.Dense(d_feature)
  if cache_KV_in_predict and mode == 'predict':
    k_processor = cb.Cache(k_processor)
    v_processor = cb.Cache(v_processor)

  if q_sparsity is None:
    q_processor = core.Dense(d_feature)
  elif q_sparsity == 'noop':
    q_processor = cb.Serial()
  else:
    d_module = d_feature // q_sparsity
    q_processor = cb.Serial(
        sparsity.MultiplicativeSparseDense(q_sparsity, d_feature, d_feature),
        sparsity.LocallyConvDense(q_sparsity, d_module, mode=mode,
                                  kernel_size=3, length_kernel_size=3))

  if result_sparsity is None:
    result_processor = core.Dense(d_feature)
  elif result_sparsity == 'noop':
    result_processor = cb.Serial()
  else:
    d_module = d_feature // result_sparsity
    result_processor = cb.Serial(
        sparsity.MultiplicativeSparseDense(result_sparsity, d_feature,
                                           d_feature),
        sparsity.LocallyConvDense(result_sparsity, d_module, mode=mode,
                                  kernel_size=3, length_kernel_size=3))

  return cb.Serial(
      cb.Parallel(
          q_processor,
          k_processor,
          v_processor,
      ),
      PureAttention(  # pylint: disable=no-value-for-parameter
          n_heads=n_heads, dropout=dropout, mode=mode),
      result_processor)
def test_state_parallel(self):
  model = cb.Parallel(core.Dense(3), core.Dense(5))
  self.assertIsInstance(model.state, tuple)
  self.assertLen(model.state, 2)
def CausalFavor(d_feature, n_heads=1, dropout=0.0, numerical_stabilizer=0.001,
                precision=None, mode='train'):
  """Returns a layer that maps activations to activations, with causal masking.

  Like `CausalAttention`, this layer type represents one pass of multi-head
  causal attention, but using FAVOR fast attention as in the following paper:
  https://arxiv.org/abs/2006.03555

  Args:
    d_feature: Depth/dimensionality of feature embedding.
    n_heads: Number of attention heads.
    dropout: Probabilistic rate for internal dropout applied to attention
        activations (based on query-key pairs) before dotting them with values.
    numerical_stabilizer: float, small number used for numerical stability.
    precision: passed to jnp.einsum to define arithmetic precision.
    mode: One of `'train'`, `'eval'`, or `'predict'`.
  """
  del dropout, mode  # not implemented yet but needed in the API

  def favor_numerator_fwd(init_prefix_sum_value, precision,
                          query_prime, key_prime, value):
    def body(p, qkv):
      (q, k, v) = qkv
      p += jnp.einsum('...m,...d->...md', k, v, precision=precision)
      x_slice = jnp.einsum('...m,...md->...d', q, p, precision=precision)
      return p, x_slice
    p, w = fastmath.scan(body, init_prefix_sum_value,
                         (query_prime, key_prime, value))
    return w, (precision, p, query_prime, key_prime, value)

  def favor_numerator_bwd(pqkv, w_ct):
    precision, p, qs, ks, vs = pqkv

    def body(carry, qkv_xct):
      p, p_ct = carry
      q, k, v, x_ct = qkv_xct
      q_ct = jnp.einsum('...d,...md->...m', x_ct, p, precision=precision)
      p_ct += jnp.einsum('...d,...m->...md', x_ct, q, precision=precision)
      k_ct = jnp.einsum('...md,...d->...m', p_ct, v, precision=precision)
      v_ct = jnp.einsum('...md,...m->...d', p_ct, k, precision=precision)
      p -= jnp.einsum('...m,...d->...md', k, v, precision=precision)
      return (p, p_ct), (q_ct, k_ct, v_ct)

    _, (qs_ct, ks_ct, vs_ct) = fastmath.scan(
        body, (p, jnp.zeros_like(p)), (qs, ks, vs, w_ct), reverse=True)
    return (None, None, qs_ct, ks_ct, vs_ct)

  def favor_numerator(init_prefix_sum_value, precision, query_prime,
                      key_prime, value):
    w, _ = favor_numerator_fwd(init_prefix_sum_value, precision,
                               query_prime, key_prime, value)
    return w

  favor_numerator = fastmath.custom_vjp(favor_numerator, favor_numerator_fwd,
                                        favor_numerator_bwd)

  def favor_denominator_fwd(init_prefix_sum_value, precision,
                            query_prime, key_prime):
    def body(p, qk):
      q, k = qk
      p += k
      x = jnp.einsum('...m,...m->...', q, p, precision=precision)
      return p, x
    p, r = fastmath.scan(body, init_prefix_sum_value,
                         (query_prime, key_prime))
    return r, (precision, query_prime, key_prime, p)

  def favor_denominator_bwd(qkp, r_ct):
    precision, qs, ks, p = qkp

    def body(carry, qkx):
      p, p_ct = carry
      q, k, x_ct = qkx
      q_ct = jnp.einsum('...,...m->...m', x_ct, p, precision=precision)
      p_ct += jnp.einsum('...,...m->...m', x_ct, q, precision=precision)
      k_ct = p_ct
      p -= k
      return (p, p_ct), (q_ct, k_ct)

    _, (qs_ct, ks_ct) = fastmath.scan(
        body, (p, jnp.zeros_like(p)), (qs, ks, r_ct), reverse=True)
    return (None, None, qs_ct, ks_ct)

  def favor_denominator(init_prefix_sum_value, precision, query_prime,
                        key_prime):
    r, _ = favor_denominator_fwd(init_prefix_sum_value, precision,
                                 query_prime, key_prime)
    return r

  favor_denominator = fastmath.custom_vjp(favor_denominator,
                                          favor_denominator_fwd,
                                          favor_denominator_bwd)

  favor_denominator.defvjp(favor_denominator_fwd, favor_denominator_bwd)

  def relu(x):
    return jnp.where(x <= 0, jnp.zeros_like(x), x)

  def favor(query, key, value):
    query_prime = relu(query) + numerical_stabilizer
    key_prime = relu(key) + numerical_stabilizer
    prefix_sum_tensor_shape = (key.shape[0], key.shape[-1], value.shape[-1])
    t_slice_shape = (key.shape[0], key.shape[-1])
    init_prefix_sum_value_numerator = jnp.zeros(prefix_sum_tensor_shape)
    init_prefix_sum_value_denominator = jnp.zeros(t_slice_shape)

    w = favor_numerator(init_prefix_sum_value_numerator, precision,
                        jnp.moveaxis(query_prime, 1, 0),
                        jnp.moveaxis(key_prime, 1, 0),
                        jnp.moveaxis(value, 1, 0))
    r = favor_denominator(init_prefix_sum_value_denominator, precision,
                          jnp.moveaxis(query_prime, 1, 0),
                          jnp.moveaxis(key_prime, 1, 0))
    w = jnp.moveaxis(w, 0, 1)
    r = jnp.moveaxis(r, 0, 1)
    r = jnp.reciprocal(r)
    r = jnp.expand_dims(r, len(r.shape))
    renormalized_attention = w * r
    return renormalized_attention

  return tl.ConfigurableAttention(
      core.Dense(d_feature), core.Dense(d_feature), core.Dense(d_feature),
      core.Dense(d_feature), n_heads=n_heads,
      qkv_attention_layer=base.Fn('CausalFAVOR', favor))
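# A minimal NumPy sketch of the causal FAVOR computation above, without the
# custom-VJP machinery: relu feature maps plus running prefix sums give the
# numerator and denominator in O(length) work per position instead of a full
# attention matrix. Single head and single example; names are illustrative.
import numpy as np

def causal_favor(query, key, value, numerical_stabilizer=0.001):
  """query, key: (length, d_feature); value: (length, d_value)."""
  q_prime = np.maximum(query, 0.0) + numerical_stabilizer
  k_prime = np.maximum(key, 0.0) + numerical_stabilizer
  num_state = np.zeros((key.shape[-1], value.shape[-1]))  # running sum of k_t v_t^T
  den_state = np.zeros(key.shape[-1])                     # running sum of k_t
  out = np.zeros_like(value)
  for t in range(query.shape[0]):
    num_state += np.outer(k_prime[t], value[t])
    den_state += k_prime[t]
    out[t] = (q_prime[t] @ num_state) / (q_prime[t] @ den_state)
  return out

q = k = np.random.rand(6, 4)
v = np.random.rand(6, 3)
print(causal_favor(q, k, v).shape)  # (6, 3)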
def test_state_serial(self):
  model = cb.Serial(core.Dense(4), core.Dense(5), core.Dense(7))
  self.assertIsInstance(model.state, tuple)
  self.assertLen(model.state, 3)
def test_weights_parallel(self):
  model = cb.Parallel(core.Dense(3), core.Dense(5))
  self.assertIsInstance(model.weights, tuple)
  self.assertLen(model.weights, 2)
def CausalFavor(d_feature, n_heads=1, dropout=0.0,  # pylint: disable=invalid-name
                numerical_stabilizer=0.001, precision=None, mode='train'):
  """Returns a layer that maps activations to activations, with causal masking.

  Like `CausalAttention`, this layer type represents one pass of multi-head
  causal attention, but using FAVOR fast attention as in the following paper:
  https://arxiv.org/abs/2006.03555

  Args:
    d_feature: Depth/dimensionality of feature embedding.
    n_heads: Number of attention heads.
    dropout: Probabilistic rate for internal dropout applied to attention
        activations (based on query-key pairs) before dotting them with values.
    numerical_stabilizer: float, small number used for numerical stability.
    precision: passed to np.einsum to define arithmetic precision.
    mode: One of `'train'`, `'eval'`, or `'predict'`.
  """
  del dropout, mode  # not implemented yet but needed in the API

  # TODO(lukaszkaiser): make an API for split/merge heads in core layers,
  # and use it here so we don't duplicate these functions.
  if d_feature % n_heads != 0:
    raise ValueError(
        f'Dimensionality of feature embedding ({d_feature}) is not a multiple '
        f'of the requested number of attention heads ({n_heads}).')
  d_head = d_feature // n_heads

  def _split_into_heads():
    """Returns a layer that reshapes tensors for multi-headed computation."""
    def f(x):
      batch_size = x.shape[0]
      seq_len = x.shape[1]

      # (b_size, seq_len, d_feature) --> (b_size*n_heads, seq_len, d_head)
      x = x.reshape((batch_size, seq_len, n_heads, d_head))
      x = x.transpose((0, 2, 1, 3))
      x = x.reshape((-1, seq_len, d_head))
      return x
    return base.Fn('SplitIntoHeads', f)

  def _merge_heads():
    """Returns a layer that undoes splitting, after multi-head computation."""
    def f(x):
      seq_len = x.shape[1]

      # (b_size*n_heads, seq_len, d_head) --> (b_size, seq_len, d_feature)
      x = x.reshape((-1, n_heads, seq_len, d_head))
      x = x.transpose((0, 2, 1, 3))
      x = x.reshape((-1, seq_len, n_heads * d_head))
      return x
    return base.Fn('MergeHeads', f)

  def favor_numerator_fwd(init_prefix_sum_value, precision,
                          query_prime, key_prime, value):
    def body(p, qkv):
      (q, k, v) = qkv
      p += np.einsum('...m,...d->...md', k, v, precision=precision)
      x_slice = np.einsum('...m,...md->...d', q, p, precision=precision)
      return p, x_slice
    p, w = fastmath.scan(body, init_prefix_sum_value,
                         (query_prime, key_prime, value))
    return w, (p, query_prime, key_prime, value)

  def favor_numerator_bwd(init_prefix_sum_value, precision, pqkv, w_ct):
    del init_prefix_sum_value

    def body(carry, qkv_xct):
      p, p_ct = carry
      q, k, v, x_ct = qkv_xct
      q_ct = np.einsum('...d,...md->...m', x_ct, p, precision=precision)
      p_ct += np.einsum('...d,...m->...md', x_ct, q, precision=precision)
      k_ct = np.einsum('...md,...d->...m', p_ct, v, precision=precision)
      v_ct = np.einsum('...md,...m->...d', p_ct, k, precision=precision)
      p -= np.einsum('...m,...d->...md', k, v, precision=precision)
      return (p, p_ct), (q_ct, k_ct, v_ct)

    p, qs, ks, vs = pqkv
    _, (qs_ct, ks_ct, vs_ct) = fastmath.scan(
        body, (p, np.zeros_like(p)), (qs, ks, vs, w_ct), reverse=True)
    return qs_ct, ks_ct, vs_ct

  def favor_numerator(init_prefix_sum_value, precision, query_prime,
                      key_prime, value):
    w, _ = favor_numerator_fwd(init_prefix_sum_value, precision,
                               query_prime, key_prime, value)
    return w

  favor_numerator = fastmath.custom_vjp(favor_numerator, favor_numerator_fwd,
                                        favor_numerator_bwd,
                                        nondiff_argnums=(0, 1))

  def favor_denominator_fwd(init_prefix_sum_value, precision,
                            query_prime, key_prime):
    def body(p, qk):
      q, k = qk
      p += k
      x = np.einsum('...m,...m->...', q, p, precision=precision)
      return p, x

    p, r = fastmath.scan(body, init_prefix_sum_value,
                         (query_prime, key_prime))
    return r, (query_prime, key_prime, p)

  def favor_denominator_bwd(init_prefix_sum_value, precision, qkp, r_ct):
    del init_prefix_sum_value

    def body(carry, qkx):
      p, p_ct = carry
      q, k, x_ct = qkx
      q_ct = np.einsum('...,...m->...m', x_ct, p, precision=precision)
      p_ct += np.einsum('...,...m->...m', x_ct, q, precision=precision)
      k_ct = p_ct
      p -= k
      return (p, p_ct), (q_ct, k_ct)

    qs, ks, p = qkp
    _, (qs_ct, ks_ct) = fastmath.scan(
        body, (p, np.zeros_like(p)), (qs, ks, r_ct), reverse=True)
    return (qs_ct, ks_ct)

  def favor_denominator(init_prefix_sum_value, precision, query_prime,
                        key_prime):
    r, _ = favor_denominator_fwd(init_prefix_sum_value, precision,
                                 query_prime, key_prime)
    return r

  favor_denominator = fastmath.custom_vjp(favor_denominator,
                                          favor_denominator_fwd,
                                          favor_denominator_bwd,
                                          nondiff_argnums=(0, 1))

  favor_denominator.defvjp(favor_denominator_fwd, favor_denominator_bwd)

  def relu(x):
    return np.where(x <= 0, np.zeros_like(x), x)

  def favor(query, key, value):
    query_prime = relu(query) + numerical_stabilizer
    key_prime = relu(key) + numerical_stabilizer
    prefix_sum_tensor_shape = (key.shape[0], key.shape[-1], value.shape[-1])
    t_slice_shape = (key.shape[0], key.shape[-1])
    init_prefix_sum_value_numerator = np.zeros(prefix_sum_tensor_shape)
    init_prefix_sum_value_denominator = np.zeros(t_slice_shape)

    w = favor_numerator(init_prefix_sum_value_numerator, precision,
                        np.moveaxis(query_prime, 1, 0),
                        np.moveaxis(key_prime, 1, 0),
                        np.moveaxis(value, 1, 0))
    r = favor_denominator(init_prefix_sum_value_denominator, precision,
                          np.moveaxis(query_prime, 1, 0),
                          np.moveaxis(key_prime, 1, 0))
    w = np.moveaxis(w, 0, 1)
    r = np.moveaxis(r, 0, 1)
    r = r + 2 * numerical_stabilizer * (np.abs(r) <= numerical_stabilizer)
    r = np.reciprocal(r)
    r = np.expand_dims(r, len(r.shape))
    renormalized_attention = w * r
    return renormalized_attention

  return cb.Serial(
      cb.Branch(
          [core.Dense(d_feature), _split_into_heads()],
          [core.Dense(d_feature), _split_into_heads()],
          [core.Dense(d_feature), _split_into_heads()],
      ),
      base.Fn('FAVOR', favor),
      _merge_heads(),
      core.Dense(d_feature),
  )