def SRU(n_units, activation=None): """SRU (Simple Recurrent Unit) layer as in https://arxiv.org/abs/1709.02755. As defined in the paper: (1) y_t = W x_t (+ B optionally, which we do) (2) f_t = sigmoid(Wf x_t + bf) (3) r_t = sigmoid(Wr x_t + br) (4) c_t = f_t * c_{t-1} + (1 - f_t) * y_t (5) h_t = r_t * activation(c_t) + (1 - r_t) * x_t We assume the input is of shape [batch, length, depth] and recurrence happens on the length dimension. This returns a single layer. It's best to use at least 2, they say in the paper, except inside a Transformer. Args: n_units: output depth of the SRU layer. activation: Optional activation function. Returns: The SRU layer. """ sigmoid_activation = activation_fns.Sigmoid() # pylint: disable=no-value-for-parameter return cb.Serial( # x cb.Branch(core.Dense(3 * n_units), []), # r_f_y, x cb.Split(n_items=3), # r, f, y, x cb.Parallel(sigmoid_activation, sigmoid_activation), # r, f, y, x base.Fn(lambda r, f, y: (y * (1.0 - f), f, r)), # y * (1 - f), f, r, x cb.Parallel([], [], cb.Branch(MakeZeroState(), [])), cb.Scan(InnerSRUCell(), axis=1), cb.Select([0], n_in=2), # act(c), r, x activation or [], base.Fn(lambda c, r, x: c * r + x * (1 - r)))
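# A minimal NumPy sketch of the recurrence in equations (1)-(5) above, to make the
# data flow concrete. Illustrative only (names assumed): the projections y, f, r
# would come from the Dense(3 * n_units) branch inside the layer itself.
import numpy as np

def sru_recurrence_reference(y, f, r, x, activation=np.tanh):
  """y, f, r, x: arrays of shape [length, depth]; returns h of the same shape."""
  c = np.zeros(y.shape[-1])
  h = np.empty_like(x)
  for t in range(x.shape[0]):
    c = f[t] * c + (1.0 - f[t]) * y[t]                  # equation (4)
    h[t] = r[t] * activation(c) + (1.0 - r[t]) * x[t]   # equation (5)
  return h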
def CausalAttention(d_feature, n_heads=1, dropout=0.0,
                    max_inference_length=2048, mode='train'):
  """Returns a layer that maps activations to activations, with causal masking.

  Like `Attention`, this layer type represents one pass of multi-head
  self-attention, but with causal masking rather than padding-based masking.

  Args:
    d_feature: Depth/dimensionality of feature embedding.
    n_heads: Number of attention heads.
    dropout: Probabilistic rate for internal dropout applied to attention
        activations (based on query-key pairs) before dotting them with values.
    max_inference_length: maximum length for inference.
    mode: One of `'train'`, `'eval'`, or `'predict'`.
  """
  if d_feature % n_heads != 0:
    raise ValueError(
        f'Dimensionality of feature embedding ({d_feature}) is not a multiple '
        f'of the requested number of attention heads ({n_heads}).')

  d_head = d_feature // n_heads

  def _split_into_heads():
    """Returns a layer that reshapes tensors for multi-headed computation."""
    def f(x):
      batch_size = x.shape[0]
      seq_len = x.shape[1]

      # (b_size, seq_len, d_feature) --> (b_size*n_heads, seq_len, d_head)
      x = x.reshape((batch_size, seq_len, n_heads, d_head))
      x = x.transpose((0, 2, 1, 3))
      x = x.reshape((-1, seq_len, d_head))
      return x
    return Fn('SplitIntoHeads', f)

  def _merge_heads():
    """Returns a layer that undoes splitting, after multi-head computation."""
    def f(x):
      seq_len = x.shape[1]

      # (b_size*n_heads, seq_len, d_head) --> (b_size, seq_len, d_feature)
      x = x.reshape((-1, n_heads, seq_len, d_head))
      x = x.transpose((0, 2, 1, 3))
      x = x.reshape((-1, seq_len, n_heads * d_head))
      return x
    return Fn('MergeHeads', f)

  return cb.Serial(
      cb.Branch(
          [core.Dense(d_feature), _split_into_heads()],
          [core.Dense(d_feature), _split_into_heads()],
          [core.Dense(d_feature), _split_into_heads()],
      ),
      DotProductCausalAttention(
          dropout=dropout, max_inference_length=max_inference_length,
          mode=mode),
      _merge_heads(),
      core.Dense(d_feature),
  )
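# The reshapes in _split_into_heads / _merge_heads above are exact inverses. A
# small standalone NumPy check of the shape bookkeeping (values illustrative):
import numpy as np

batch, seq_len, n_heads, d_head = 2, 5, 4, 8
d_feature = n_heads * d_head
x = np.arange(batch * seq_len * d_feature, dtype=np.float32).reshape(
    (batch, seq_len, d_feature))

# (batch, seq_len, d_feature) --> (batch * n_heads, seq_len, d_head)
split = x.reshape((batch, seq_len, n_heads, d_head)).transpose(
    (0, 2, 1, 3)).reshape((-1, seq_len, d_head))
assert split.shape == (batch * n_heads, seq_len, d_head)

# (batch * n_heads, seq_len, d_head) --> (batch, seq_len, d_feature)
merged = split.reshape((-1, n_heads, seq_len, d_head)).transpose(
    (0, 2, 1, 3)).reshape((-1, seq_len, n_heads * d_head))
assert np.array_equal(merged, x)  # round trip restores the original layout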
def RelativeAttentionLMLayer(d_feature, total_kv_pooling, n_heads=1, dropout=0.0, n_raw_tokens_generated=1, max_inference_length=3072, chunk_len=None, chunk_offset=None, mode='train'): """Returns a layer that maps (q, k, v) to (activations). Same as standard Relative attention layer but additionally based on sizes of queries and keys prepares a mask that masks out the future. Masking the future is the concept primarily used for Language Modelling. Args: d_feature: Depth/dimensionality of feature embedding. total_kv_pooling: Accumulated pool size of keys/values used at this layer. n_heads: Number of attention heads. dropout: Probabilistic rate for internal dropout applied to attention activations (based on query-key pairs) before dotting them with values. n_raw_tokens_generated: Number of tokens generated in a single pass through this layer. Used only in 'predict' non-training mode. max_inference_length: Maximum sequence length allowed in non-training modes. chunk_len (optional): Number of tokens per chunk. Setting this option will enable chunked attention. chunk_offset (optional): Offset for shifting chunks, for shifted chunked attention mode: One of `'train'`, `'eval'`, or `'predict'`. """ attention = RelativeAttentionLayer( d_feature, total_kv_pooling, n_heads=n_heads, dropout=dropout, n_raw_tokens_generated=n_raw_tokens_generated, max_inference_length=max_inference_length, chunk_len=chunk_len, chunk_offset=chunk_offset, mode=mode) mask_layer = AttentionMaskLayer( total_kv_pooling=total_kv_pooling, max_inference_length=max_inference_length, chunk_len=chunk_len, chunk_offset=chunk_offset, n_raw_tokens_generated=n_raw_tokens_generated, mode=mode) return cb.Serial( cb.Branch( None, mask_layer, # vecs, mask ), attention, # vecs, mask cb.Select([0], n_in=2), # vecs )
def LSTM(n_units):
  """LSTM running on axis 1."""
  zero_state = MakeZeroState(depth_multiplier=2)  # pylint: disable=no-value-for-parameter
  return cb.Serial(
      cb.Branch([], zero_state),
      cb.Scan(LSTMCell(n_units=n_units), axis=1),
      cb.Select([0], n_in=2)  # Drop RNN state.
  )
def GeneralGRUCell(candidate_transform, memory_transform_fn=None, gate_nonlinearity=activation_fns.Sigmoid, candidate_nonlinearity=activation_fns.Tanh, dropout_rate_c=0.1, sigmoid_bias=0.5): r"""Parametrized Gated Recurrent Unit (GRU) cell construction. GRU update equations: $$ Update gate: u_t = \sigmoid(U' * s_{t-1} + B') $$ $$ Reset gate: r_t = \sigmoid(U'' * s_{t-1} + B'') $$ $$ Candidate memory: c_t = \tanh(U * (r_t \odot s_{t-1}) + B) $$ $$ New State: s_t = u_t \odot s_{t-1} + (1 - u_t) \odot c_t $$ See combinators.Gate for details on the gating function. Args: candidate_transform: Transform to apply inside the Candidate branch. Applied before nonlinearities. memory_transform_fn: Optional transformation on the memory before gating. gate_nonlinearity: Function to use as gate activation. Allows trying alternatives to Sigmoid, such as HardSigmoid. candidate_nonlinearity: Nonlinearity to apply after candidate branch. Allows trying alternatives to traditional Tanh, such as HardTanh dropout_rate_c: Amount of dropout on the transform (c) gate. Dropout works best in a GRU when applied exclusively to this branch. sigmoid_bias: Constant to add before sigmoid gates. Generally want to start off with a positive bias. Returns: A model representing a GRU cell with specified transforms. """ gate_block = [ # u_t candidate_transform(), base.Fn(lambda x: x + sigmoid_bias), gate_nonlinearity(), ] reset_block = [ # r_t candidate_transform(), base.Fn(lambda x: x + sigmoid_bias), # Want bias to start positive. gate_nonlinearity(), ] candidate_block = [ cb.Dup(), reset_block, cb.Multiply(), # Gate S{t-1} with sigmoid(candidate_transform(S{t-1})) candidate_transform(), # Final projection + tanh to get Ct candidate_nonlinearity(), # Candidate gate # Only apply dropout on the C gate. Paper reports 0.1 as a good default. core.Dropout(rate=dropout_rate_c) ] memory_transform = memory_transform_fn() if memory_transform_fn else [] return cb.Serial( cb.Branch(memory_transform, gate_block, candidate_block), cb.Gate(), )
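# A small NumPy sketch of the update equations in the docstring above, with
# illustrative weight names (U_u, U_r, U_c and biases); the actual cell builds
# these transforms from candidate_transform() and the gating combinators.
import numpy as np

def _sigmoid(z):
  return 1.0 / (1.0 + np.exp(-z))

def gru_state_update(s_prev, U_u, B_u, U_r, B_r, U_c, B_c):
  u = _sigmoid(U_u @ s_prev + B_u)         # update gate u_t
  r = _sigmoid(U_r @ s_prev + B_r)         # reset gate r_t
  c = np.tanh(U_c @ (r * s_prev) + B_c)    # candidate memory c_t
  return u * s_prev + (1.0 - u) * c        # new state s_t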
def RelativeAttentionLayer(d_feature,
                           context_bias_layer,
                           location_bias_layer,
                           total_kv_pooling,
                           separate_cls,
                           n_heads=1,
                           dropout=0.0,
                           mode='train'):
  """Returns a layer that maps (q, k, v, masks) to (activations, masks).

  When the number of keys is smaller than the number of queries, the layer
  works in O(q^2*d); otherwise it is O(q*k*d). That is because we need to
  shift relative distances by current_pooling, and when we upsample this
  current pooling is a fraction < 1.

  Visual explanation:
  [01][23][45][67] -> [0][1][2][3][4][5][6][7]

  For token [0] we calculate relative distances as follows:
  * 0 2 4 6

  However for token [1] we need relative distances changed by 1, specifically:
  * -1 1 3 5

  So we not only need to calculate the distances that correspond to the
  spacing between the keys but also the ones in between, because there is more
  than one query token (at different positions, which means different relative
  distances) for a single key token.

  Args:
    d_feature: Depth/dimensionality of feature embedding.
    context_bias_layer: Global context bias from Transformer XL's attention.
        There should be one such layer shared for all relative attention
        layers.
    location_bias_layer: Global location bias from Transformer XL's attention.
        There should be one such layer shared for all relative attention
        layers.
    total_kv_pooling: Accumulated pool size of keys/values used at this layer.
    separate_cls: True/False if we separate_cls in calculations.
    n_heads: Number of attention heads.
    dropout: Probabilistic rate for internal dropout applied to attention
        activations (based on query-key pairs) before dotting them with values.
    mode: One of `'train'`, `'eval'`, or `'predict'`.
  """
  return cb.Serial(
      cb.Branch(
          PositionalEmbeddings(d_feature, separate_cls, total_kv_pooling),
          cb.Select([0]),
          cb.Select([1])),
      cb.Parallel(
          core.Dense(d_feature),
          core.Dense(d_feature),
          core.Dense(d_feature),
          core.Dense(d_feature),
      ),
      context_bias_layer,
      location_bias_layer,
      RelativeAttention(  # pylint: disable=no-value-for-parameter
          separate_cls=separate_cls,
          n_heads=n_heads,
          dropout=dropout,
          mode=mode),
      core.Dense(d_feature),
  )
def LSTM(n_units, mode='train', return_state=False, initial_state=False): """LSTM running on axis 1. Args: n_units: `n_units` for the `LSTMCell`. mode: if 'predict' then we save the previous state for one-by-one inference. return_state: Boolean. Whether to return the latest status in addition to the output. Default: False. initial_state: Boolean. If the state RNN (c, h) is to be obtained from the stack. Default: False. Returns: A LSTM layer. """ if not initial_state: zero_state = MakeZeroState(depth_multiplier=2) # pylint: disable=no-value-for-parameter if return_state: return cb.Serial(cb.Branch([], zero_state), cb.Scan(LSTMCell(n_units=n_units), axis=1, mode=mode), name=f'LSTM_{n_units}', sublayers_to_print=[]) else: return cb.Serial( cb.Branch([], zero_state), # fill state RNN with zero. cb.Scan(LSTMCell(n_units=n_units), axis=1, mode=mode), cb.Select([0], n_in=2), # Drop RNN state. # Set the name to LSTM and don't print sublayers. name=f'LSTM_{n_units}', sublayers_to_print=[]) else: if return_state: return cb.Serial(cb.Scan(LSTMCell(n_units=n_units), axis=1, mode=mode), name=f'LSTM_{n_units}', sublayers_to_print=[]) else: return cb.Serial( cb.Scan(LSTMCell(n_units=n_units), axis=1, mode=mode), cb.Select([0], n_in=2), # Drop RNN state. name=f'LSTM_{n_units}', sublayers_to_print=[])
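# A hedged usage sketch (model shape and sizes are illustrative, not from the
# library): a small sequence model that keeps only the LSTM outputs. With the
# defaults return_state=False and initial_state=False, the layer branches in a
# zero state, scans LSTMCell over axis 1, and drops the final state.
model = cb.Serial(
    core.Embedding(vocab_size=256, d_feature=128),
    LSTM(n_units=128, mode='train'),
    core.Dense(256),
)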
def GRU(n_units):
  """GRU running on axis 1."""
  zero_state = MakeZeroState(depth_multiplier=1)  # pylint: disable=no-value-for-parameter
  return cb.Serial(
      cb.Branch([], zero_state),
      cb.Scan(GRUCell(n_units=n_units), axis=1),
      cb.Select([0], n_in=2),  # Drop RNN state.
      # Set the name to GRU and don't print sublayers.
      name=f'GRU_{n_units}', sublayers_to_print=[]
  )
def LSTM(n_units, mode='train'):
  """LSTM running on axis 1."""
  zero_state = MakeZeroState(depth_multiplier=2)  # pylint: disable=no-value-for-parameter
  return cb.Serial(
      cb.Branch([], zero_state),
      cb.Scan(LSTMCell(n_units=n_units), axis=1, mode=mode),
      cb.Select([0], n_in=2),  # Drop RNN state.
      # Set the name to LSTM and don't print sublayers.
      name=f'LSTM_{n_units}', sublayers_to_print=[]
  )
def SRU(n_units, activation=None, mode='train'): r"""SRU (Simple Recurrent Unit) layer as in https://arxiv.org/abs/1709.02755. As defined in the paper: .. math:: y_t &= W x_t + B \quad \hbox{(include $B$ optionally)} \\ f_t &= \sigma(Wf x_t + bf) \\ r_t &= \sigma(Wr x_t + br) \\ c_t &= f_t \times c_{t-1} + (1 - f_t) \times y_t \\ h_t &= r_t \times \hbox{activation}(c_t) + (1 - r_t) \times x_t We assume the input is of shape [batch, length, depth] and recurrence happens on the length dimension. This returns a single layer. It's best to use at least 2, they say in the paper, except inside a Transformer. Args: n_units: output depth of the SRU layer. activation: Optional activation function. mode: if 'predict' then we save the previous state for one-by-one inference Returns: The SRU layer. """ sigmoid_activation = activation_fns.Sigmoid() return cb.Serial( # x cb.Branch(core.Dense(3 * n_units), []), # r_f_y, x cb.Split(n_items=3), # r, f, y, x cb.Parallel(sigmoid_activation, sigmoid_activation), # r, f, y, x base.Fn( '', lambda r, f, y: (y * (1.0 - f), f, r), # y * (1 - f), f, r, x n_out=3), cb.Parallel([], [], cb.Branch(MakeZeroState(), [])), ScanSRUCell(mode=mode), cb.Select([0], n_in=2), # act(c), r, x activation if activation is not None else [], base.Fn('FinalSRUGate', lambda c, r, x: c * r + x * (1 - r) * (3**0.5)), # Set the name to SRU and don't print sublayers. name=f'SRU_{n_units}', sublayers_to_print=[])
def ConfigurableAttention(q_layer, k_layer, v_layer, final_layer,  # pylint: disable=invalid-name
                          qkv_attention_layer, n_heads=1):
  return cb.Serial(
      cb.Branch(
          [q_layer, SplitIntoHeads(n_heads)],
          [k_layer, SplitIntoHeads(n_heads)],
          [v_layer, SplitIntoHeads(n_heads)],
      ),
      qkv_attention_layer,
      MergeHeads(n_heads),
      final_layer)
def SRU(n_units, activation=None): r"""SRU (Simple Recurrent Unit) layer as in https://arxiv.org/abs/1709.02755. As defined in the paper: .. math:: y_t &= W x_t + B \quad \hbox{(include $B$ optionally)} \\ f_t &= \sigma(Wf x_t + bf) \\ r_t &= \sigma(Wr x_t + br) \\ c_t &= f_t \times c_{t-1} + (1 - f_t) \times y_t \\ h_t &= r_t \times \hbox{activation}(c_t) + (1 - r_t) \times x_t We assume the input is of shape [batch, length, depth] and recurrence happens on the length dimension. This returns a single layer. It's best to use at least 2, they say in the paper, except inside a Transformer. Args: n_units: output depth of the SRU layer. activation: Optional activation function. Returns: The SRU layer. """ sigmoid_activation = activation_fns.Sigmoid() return cb.Serial( # x cb.Branch(core.Dense(3 * n_units), []), # r_f_y, x cb.Split(n_items=3), # r, f, y, x cb.Parallel(sigmoid_activation, sigmoid_activation), # r, f, y, x base.Fn( '', lambda r, f, y: (y * (1.0 - f), f, r), # y * (1 - f), f, r, x n_out=3), cb.Parallel([], [], cb.Branch(MakeZeroState(), [])), cb.Scan(InnerSRUCell(), axis=1), cb.Select([0], n_in=2), # act(c), r, x activation or [], base.Fn('FinalSRUGate', lambda c, r, x: c * r + x * (1 - r) * (3**0.5)))
def SRU(n_units, activation=None, rescale=False, highway_bias=0):
  """SRU (Simple Recurrent Unit) layer as in https://arxiv.org/abs/1709.02755.

  As defined in the paper:
  (1) y_t = W x_t (+ B optionally, which we do)
  (2) f_t = sigmoid(Wf x_t + bf)
  (3) r_t = sigmoid(Wr x_t + br)
  (4) c_t = f_t * c_{t-1} + (1 - f_t) * y_t
  (5) h_t = r_t * activation(c_t) + (1 - r_t) * x_t * alpha

  We assume the input is of shape [batch, length, depth] and recurrence
  happens on the length dimension. This returns a single layer. It's best
  to use at least 2, they say in the paper, except inside a Transformer.

  Args:
    n_units: output depth of the SRU layer.
    activation: Optional activation function.
    rescale: To offset the problem of the gradient vanishing in the h_t as a
        result of light recurrence and highway computation for deeper layers, a
        scaling correction alpha is applied as follows:
        (1 + exp(highway_bias) * 2)**0.5
        ref: https://arxiv.org/abs/1709.02755, page 4, section 3.2
        Initialization.
    highway_bias: initial bias of highway gates

  Returns:
    The SRU layer.
  """
  # pylint: disable=no-value-for-parameter
  return cb.Serial(                                          # x
      cb.Branch(core.Dense(3 * n_units), []),                # r_f_y, x
      cb.Split(n_items=3),                                   # r, f, y, x
      cb.Parallel(core.Sigmoid(), core.Sigmoid()),           # r, f, y, x
      base.Fn(lambda r, f, y: (y * (1.0 - f), f, r)),        # y * (1 - f), f, r, x
      cb.Parallel([], [], cb.Branch(MakeZeroState(), [])),
      cb.Scan(InnerSRUCell(), axis=1),
      cb.Select([0], n_in=2),                                # act(c), r, x
      activation or [],
      base.Fn(lambda c, r, x: c * r + x * (1 - r) *
              ((1 + np.exp(highway_bias) * 2)**0.5 if rescale else 1)))
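# Worked value of the rescale correction used above: with the default
# highway_bias=0, alpha = (1 + exp(0) * 2)**0.5 = 3**0.5 ~= 1.732, which is the
# fixed 3**0.5 factor appearing in the FinalSRUGate variants of SRU above.
import numpy as np
alpha = (1.0 + np.exp(0.0) * 2.0) ** 0.5
assert abs(alpha - 3 ** 0.5) < 1e-12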
def AttentionResampling(shorten_factor, d_model, is_upsampling, d_ff, n_heads, dropout, dropout_shared_axes, mode, ff_activation, context_bias_layer, location_bias_layer, total_pooling, resampling_fn): """Attention resampling.""" attention = RelativeAttentionLMLayer(d_model, context_bias_layer, location_bias_layer, total_pooling, n_heads=n_heads, dropout=dropout, mode=mode) feed_forward = FeedForwardBlock(d_model, d_ff, dropout, dropout_shared_axes, mode, ff_activation) resampling = resampling_fn(shorten_factor, d_model, mode=mode) def _Dropout(): return core.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode) return [ LayerNorm(), # h cb.Branch(cb.Serial( resampling, LayerNorm(), ), None), # h', h cb.Serial( # pylint: disable=g-long-ternary cb.Select([0, 2, 1, 2]), cb.Add(), ) if is_upsampling else [], cb.Residual( cb.Select([0, 1, 1]), # h', h, h attention, _Dropout(), ), cb.Residual( LayerNorm(), feed_forward, _Dropout(), ), ]
def CausalAttention(d_feature, n_heads=1, dropout=0.0, mode='train'): """Transformer-style multi-headed causal attention. Args: d_feature: int: dimensionality of feature embedding n_heads: int: number of attention heads dropout: float: attention dropout mode: str: 'train' or 'eval' Returns: Multi-headed self-attention result. """ assert d_feature % n_heads == 0 d_head = d_feature // n_heads def compute_attention_heads(x): batch_size = x.shape[0] seqlen = x.shape[1] # n_batch, seqlen, n_heads*d_head -> n_batch, seqlen, n_heads, d_head x = jnp.reshape(x, (batch_size, seqlen, n_heads, d_head)) # n_batch, seqlen, n_heads, d_head -> n_batch, n_heads, seqlen, d_head x = jnp.transpose(x, (0, 2, 1, 3)) # n_batch, n_heads, seqlen, d_head -> n_batch*n_heads, seqlen, d_head return jnp.reshape(x, (-1, seqlen, d_head)) ComputeAttentionHeads = Fn('ComputeAttentionHeads', compute_attention_heads) def compute_attention_output(x): seqlen = x.shape[1] x = jnp.reshape(x, (-1, n_heads, seqlen, d_head)) x = jnp.transpose(x, (0, 2, 1, 3)) # -> n_batch, seqlen, n_heads, d_head return jnp.reshape(x, (-1, seqlen, n_heads * d_head)) return cb.Serial( cb.Branch( [core.Dense(d_feature), ComputeAttentionHeads], [core.Dense(d_feature), ComputeAttentionHeads], [core.Dense(d_feature), ComputeAttentionHeads], ), DotProductCausalAttention(dropout=dropout, mode=mode), Fn('ComputeAttentionOutput', compute_attention_output), core.Dense(d_feature) )
def ConfigurableAttention(q_layer, k_layer, v_layer, final_layer,  # pylint: disable=invalid-name
                          qkv_attention_layer, n_heads=1):
  """Returns a configured multi-head self-attention layer.

  A :py:class:`ConfigurableAttention` layer acts similarly to
  :py:class:`Attention` layers, but with configurable components. It

    - makes three copies of incoming activations and uses ``q_layer``,
      ``k_layer``, and ``v_layer`` to map activations to multi-head query (Q)
      vectors, key (K) vectors, and value (V) vectors, respectively;

    - uses ``qkv_attention_layer`` to compute per-head attention, similar to
      :py:class:`DotProductAttention` or :py:class:`DotProductCausalAttention`;

    - concatenates and fuses resulting per-head vectors into activations
      matching original input activation shapes; and

    - applies a final layer, ``final_layer``, mapping activations to
      activations (with shape matching the original input activations).

  Args:
    q_layer: Layer that maps input activations to per-head query activations.
    k_layer: Layer that maps input activations to per-head key activations.
    v_layer: Layer that maps input activations to per-head value activations.
    final_layer: After main multi-head computation and rejoining of heads,
        layer that maps activations to activations (with shape matching the
        original input activations).
    qkv_attention_layer: Layer that does the core multi-head self-attention
        computation.
    n_heads: Number of attention heads. Attention heads effectively split
        activation vectors into ``n_heads`` subvectors, of size
        ``d_feature / n_heads``.
  """
  return cb.Serial(
      cb.Branch(
          [q_layer, SplitIntoHeads(n_heads)],
          [k_layer, SplitIntoHeads(n_heads)],
          [v_layer, SplitIntoHeads(n_heads)],
      ),
      qkv_attention_layer,
      MergeHeads(n_heads),
      final_layer)
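# A hedged sketch of how a causal attention block could be assembled from
# ConfigurableAttention, mirroring the Dense + DotProductCausalAttention
# structure used by CausalAttention above. The wrapper name is illustrative.
def _CausalAttentionFromConfigurable(d_feature, n_heads=1, dropout=0.0,
                                     mode='train'):
  return ConfigurableAttention(
      core.Dense(d_feature),   # q_layer
      core.Dense(d_feature),   # k_layer
      core.Dense(d_feature),   # v_layer
      core.Dense(d_feature),   # final_layer
      qkv_attention_layer=DotProductCausalAttention(dropout=dropout,
                                                    mode=mode),
      n_heads=n_heads)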
def CausalFavor(d_feature, n_heads=1, dropout=0.0,  # pylint: disable=invalid-name
                numerical_stabilizer=0.001, precision=None, mode='train'):
  """Returns a layer that maps activations to activations, with causal masking.

  Like `CausalAttention`, this layer type represents one pass of multi-head
  causal attention, but using FAVOR fast attention as in the following paper:
  https://arxiv.org/abs/2006.03555

  Args:
    d_feature: Depth/dimensionality of feature embedding.
    n_heads: Number of attention heads.
    dropout: Probabilistic rate for internal dropout applied to attention
        activations (based on query-key pairs) before dotting them with values.
    numerical_stabilizer: float, small number used for numerical stability.
    precision: passed to np.einsum to define arithmetic precision.
    mode: One of `'train'`, `'eval'`, or `'predict'`.
  """
  del dropout, mode  # not implemented yet but needed in the API

  # TODO(lukaszkaiser): make an API for split/merge heads in core layers,
  # and use it here so we don't duplicate these functions.
  if d_feature % n_heads != 0:
    raise ValueError(
        f'Dimensionality of feature embedding ({d_feature}) is not a multiple '
        f'of the requested number of attention heads ({n_heads}).')

  d_head = d_feature // n_heads

  def _split_into_heads():
    """Returns a layer that reshapes tensors for multi-headed computation."""
    def f(x):
      batch_size = x.shape[0]
      seq_len = x.shape[1]

      # (b_size, seq_len, d_feature) --> (b_size*n_heads, seq_len, d_head)
      x = x.reshape((batch_size, seq_len, n_heads, d_head))
      x = x.transpose((0, 2, 1, 3))
      x = x.reshape((-1, seq_len, d_head))
      return x
    return base.Fn('SplitIntoHeads', f)

  def _merge_heads():
    """Returns a layer that undoes splitting, after multi-head computation."""
    def f(x):
      seq_len = x.shape[1]

      # (b_size*n_heads, seq_len, d_head) --> (b_size, seq_len, d_feature)
      x = x.reshape((-1, n_heads, seq_len, d_head))
      x = x.transpose((0, 2, 1, 3))
      x = x.reshape((-1, seq_len, n_heads * d_head))
      return x
    return base.Fn('MergeHeads', f)

  def favor_numerator_fwd(init_prefix_sum_value, precision,
                          query_prime, key_prime, value):
    def body(p, qkv):
      (q, k, v) = qkv
      p += np.einsum('...m,...d->...md', k, v, precision=precision)
      x_slice = np.einsum('...m,...md->...d', q, p, precision=precision)
      return p, x_slice
    p, w = fastmath.scan(body, init_prefix_sum_value,
                         (query_prime, key_prime, value))
    return w, (p, query_prime, key_prime, value)

  def favor_numerator_bwd(init_prefix_sum_value, precision, pqkv, w_ct):
    del init_prefix_sum_value

    def body(carry, qkv_xct):
      p, p_ct = carry
      q, k, v, x_ct = qkv_xct
      q_ct = np.einsum('...d,...md->...m', x_ct, p, precision=precision)
      p_ct += np.einsum('...d,...m->...md', x_ct, q, precision=precision)
      k_ct = np.einsum('...md,...d->...m', p_ct, v, precision=precision)
      v_ct = np.einsum('...md,...m->...d', p_ct, k, precision=precision)
      p -= np.einsum('...m,...d->...md', k, v, precision=precision)
      return (p, p_ct), (q_ct, k_ct, v_ct)

    p, qs, ks, vs = pqkv
    _, (qs_ct, ks_ct, vs_ct) = fastmath.scan(
        body, (p, np.zeros_like(p)), (qs, ks, vs, w_ct), reverse=True)
    return qs_ct, ks_ct, vs_ct

  def favor_numerator(init_prefix_sum_value, precision, query_prime,
                      key_prime, value):
    w, _ = favor_numerator_fwd(init_prefix_sum_value, precision,
                               query_prime, key_prime, value)
    return w

  favor_numerator = fastmath.custom_vjp(
      favor_numerator, favor_numerator_fwd, favor_numerator_bwd,
      nondiff_argnums=(0, 1))

  def favor_denominator_fwd(init_prefix_sum_value, precision,
                            query_prime, key_prime):
    def body(p, qk):
      q, k = qk
      p += k
      x = np.einsum('...m,...m->...', q, p, precision=precision)
      return p, x
    p, r = fastmath.scan(body, init_prefix_sum_value,
                         (query_prime, key_prime))
    return r, (query_prime, key_prime, p)

  def favor_denominator_bwd(init_prefix_sum_value, precision, qkp, r_ct):
    del init_prefix_sum_value

    def body(carry, qkx):
      p, p_ct = carry
      q, k, x_ct = qkx
      q_ct = np.einsum('...,...m->...m', x_ct, p, precision=precision)
      p_ct += np.einsum('...,...m->...m', x_ct, q, precision=precision)
      k_ct = p_ct
      p -= k
      return (p, p_ct), (q_ct, k_ct)

    qs, ks, p = qkp
    _, (qs_ct, ks_ct) = fastmath.scan(
        body, (p, np.zeros_like(p)), (qs, ks, r_ct), reverse=True)
    return (qs_ct, ks_ct)

  def favor_denominator(init_prefix_sum_value, precision, query_prime,
                        key_prime):
    r, _ = favor_denominator_fwd(init_prefix_sum_value, precision,
                                 query_prime, key_prime)
    return r

  favor_denominator = fastmath.custom_vjp(
      favor_denominator, favor_denominator_fwd, favor_denominator_bwd,
      nondiff_argnums=(0, 1))

  favor_denominator.defvjp(favor_denominator_fwd, favor_denominator_bwd)

  def relu(x):
    return np.where(x <= 0, np.zeros_like(x), x)

  def favor(query, key, value):
    query_prime = relu(query) + numerical_stabilizer
    key_prime = relu(key) + numerical_stabilizer
    prefix_sum_tensor_shape = (key.shape[0], key.shape[-1], value.shape[-1])
    t_slice_shape = (key.shape[0], key.shape[-1])
    init_prefix_sum_value_numerator = np.zeros(prefix_sum_tensor_shape)
    init_prefix_sum_value_denominator = np.zeros(t_slice_shape)

    w = favor_numerator(init_prefix_sum_value_numerator, precision,
                        np.moveaxis(query_prime, 1, 0),
                        np.moveaxis(key_prime, 1, 0),
                        np.moveaxis(value, 1, 0))
    r = favor_denominator(init_prefix_sum_value_denominator, precision,
                          np.moveaxis(query_prime, 1, 0),
                          np.moveaxis(key_prime, 1, 0))
    w = np.moveaxis(w, 0, 1)
    r = np.moveaxis(r, 0, 1)
    r = r + 2 * numerical_stabilizer * (np.abs(r) <= numerical_stabilizer)
    r = np.reciprocal(r)
    r = np.expand_dims(r, len(r.shape))
    renormalized_attention = w * r
    return renormalized_attention

  return cb.Serial(
      cb.Branch(
          [core.Dense(d_feature), _split_into_heads()],
          [core.Dense(d_feature), _split_into_heads()],
          [core.Dense(d_feature), _split_into_heads()],
      ),
      base.Fn('FAVOR', favor),
      _merge_heads(),
      core.Dense(d_feature),
  )
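# A compact NumPy reference for the causal FAVOR computation above, for a
# single head and without the custom VJPs (illustrative only). The running
# prefix sums p (over k'v^T) and z (over k') give the numerator and
# denominator, so attention costs O(L * m * d) instead of O(L^2 * d).
import numpy as np

def causal_favor_reference(query, key, value, numerical_stabilizer=0.001):
  """query, key: [length, m]; value: [length, d]."""
  relu = lambda x: np.where(x <= 0, np.zeros_like(x), x)
  q_prime = relu(query) + numerical_stabilizer
  k_prime = relu(key) + numerical_stabilizer
  p = np.zeros((key.shape[-1], value.shape[-1]))   # running sum of k_t v_t^T
  z = np.zeros(key.shape[-1])                      # running sum of k_t
  out = np.zeros_like(value, dtype=float)
  for t in range(query.shape[0]):
    p += np.outer(k_prime[t], value[t])
    z += k_prime[t]
    out[t] = (q_prime[t] @ p) / (q_prime[t] @ z)
  return out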
def test_branch_name(self):
  layer = cb.Branch(cb.Add(), divide_by(0.5))  # pylint: disable=no-value-for-parameter
  self.assertIn('Branch', str(layer))

def test_branch_one_layer(self):
  layer = cb.Branch(divide_by(0.5))
  input_signature = ShapeDtype((3, 2))
  expected_shape = (3, 2)
  output_shape = base.check_shape_agreement(layer, input_signature)
  self.assertEqual(output_shape, expected_shape)

def test_branch_add_div(self):
  layer = cb.Branch(cb.Add(), divide_by(0.5))
  input_signature = (ShapeDtype((3, 2)), ShapeDtype((3, 2)))
  expected_shape = ((3, 2), (3, 2))
  output_shape = base.check_shape_agreement(layer, input_signature)
  self.assertEqual(output_shape, expected_shape)

def test_branch_noop_dup(self):
  layer = cb.Branch([], cb.Dup())
  input_signature = ShapeDtype((3, 2))
  expected_shape = ((3, 2), (3, 2), (3, 2))
  output_shape = base.check_shape_agreement(layer, input_signature)
  self.assertEqual(output_shape, expected_shape)

def test_branch_name(self):
  layer = cb.Branch(cb.Add(), divide_by(0.5))
  self.assertIn('Branch', str(layer))

def test_branch_op_not_defined(self):
  with self.assertRaises(AttributeError):
    cb.Branch([], [])
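# An illustrative companion to the tests above (not from the library's test
# suite; assumes numpy is available as np in the test module): Branch applies
# each sublayer to copies of the same inputs and pushes all outputs onto the
# stack, so an identity branch plus Dup yields three copies of the input.
def test_branch_noop_dup_outputs(self):
  layer = cb.Branch([], cb.Dup())
  x = np.ones((3, 2))
  outputs = layer(x)
  self.assertEqual(len(outputs), 3)
  for y in outputs:
    self.assertEqual(y.shape, (3, 2))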
def RelativeAttentionLayer(d_feature,
                           total_kv_pooling,
                           n_heads=1,
                           dropout=0.0,
                           n_raw_tokens_generated=1,
                           max_inference_length=3072,
                           chunk_len=None,
                           chunk_offset=None,
                           mode='train'):
  """Returns a layer that maps (q, k, v, masks) to (activations, masks).

  When the number of keys is smaller than the number of queries, the layer
  works in O(q^2*d); otherwise it is O(q*k*d). That is because we need to
  shift relative distances by current_pooling, and when we upsample this
  current pooling is a fraction < 1.

  Visual explanation:
  [01][23][45][67] -> [0][1][2][3][4][5][6][7]

  For token [0] we calculate relative distances as follows:
  * 0 2 4 6

  However for token [1] we need relative distances changed by 1, specifically:
  * -1 1 3 5

  So we not only need to calculate the distances that correspond to the
  spacing between the keys but also the ones in between, because there is more
  than one query token (at different positions, which means different relative
  distances) for a single key token.

  Args:
    d_feature: Depth/dimensionality of feature embedding.
    total_kv_pooling: Accumulated pool size of keys/values used at this layer.
    n_heads: Number of attention heads.
    dropout: Probabilistic rate for internal dropout applied to attention
        activations (based on query-key pairs) before dotting them with values.
    n_raw_tokens_generated: Number of tokens generated in a single pass through
        this layer. Used only in 'predict' non-training mode.
    max_inference_length: Maximum sequence length allowed in non-training
        modes.
    chunk_len (optional): Number of tokens per chunk. Setting this option will
        enable chunked attention.
    chunk_offset (optional): Offset for shifting chunks, for shifted chunked
        attention.
    mode: One of `'train'`, `'eval'`, or `'predict'`.
  """
  pos_emb = PositionalEmbeddings(
      d_feature,
      total_kv_pooling,
      max_inference_length=max_inference_length,
      chunk_len=chunk_len,
      chunk_offset=chunk_offset,
      n_raw_tokens_generated=n_raw_tokens_generated,
      mode=mode)

  attention = RelativeAttention(  # pylint: disable=no-value-for-parameter
      total_kv_pooling=total_kv_pooling,
      n_heads=n_heads,
      dropout=dropout,
      n_raw_tokens_generated=n_raw_tokens_generated,
      max_inference_length=max_inference_length,
      chunk_len=chunk_len,
      chunk_offset=chunk_offset,
      mode=mode)

  assert d_feature % n_heads == 0
  d_head = d_feature // n_heads
  context_bias_layer = core.Weights(
      init.RandomNormalInitializer(1e-6), shape=(1, n_heads, 1, d_head))
  location_bias_layer = core.Weights(
      init.RandomNormalInitializer(1e-6), shape=(1, n_heads, 1, d_head))

  return cb.Serial(
      cb.Branch(
          cb.Serial(pos_emb, core.Dense(d_feature)),
          core.Dense(d_feature),
          core.Dense(d_feature),
          core.Dense(d_feature),
          cb.Select([1])  # mask
      ),
      context_bias_layer,
      location_bias_layer,
      attention,
      core.Dense(d_feature),
  )
def ModularCausalAttention(d_feature, n_heads=1, dropout=0.0,  # pylint: disable=invalid-name
                           max_inference_length=2048, n_modules=1,
                           mode='train'):
  """Returns a layer that maps activations to activations, with causal masking.

  Like `CausalAttention`, this layer type represents one pass of multi-head
  self-attention with causal masking rather than padding-based masking.
  However, it uses LocallyConnectedDense instead of Dense layer for computing
  K/Q/V.

  Args:
    d_feature: Depth/dimensionality of feature embedding.
    n_heads: Number of attention heads.
    dropout: Probabilistic rate for internal dropout applied to attention
        activations (based on query-key pairs) before dotting them with values.
    max_inference_length: maximum length for inference.
    n_modules: Number of modules used in LocallyConnectedDense.
    mode: One of `'train'`, `'eval'`, or `'predict'`.
  """
  if d_feature % n_heads != 0:
    raise ValueError(
        f'Dimensionality of feature embedding ({d_feature}) is not a multiple '
        f'of the requested number of attention heads ({n_heads}).')

  d_head = d_feature // n_heads

  @assert_shape('bld->hlx')
  def _split_into_heads():
    """Returns a layer that reshapes tensors for multi-headed computation."""
    def f(x):
      batch_size = x.shape[0]
      seq_len = x.shape[1]

      # (b_size, seq_len, d_feature) --> (b_size*n_heads, seq_len, d_head)
      x = x.reshape((batch_size, seq_len, n_heads, d_head))
      x = x.transpose((0, 2, 1, 3))
      x = x.reshape((batch_size * n_heads, seq_len, d_head))
      return x
    return tl.Fn('SplitIntoHeads', f)

  @assert_shape('hlx->bld')
  def _merge_heads():
    """Returns a layer that undoes splitting, after multi-head computation."""
    def f(x):
      seq_len = x.shape[1]

      # (b_size*n_heads, seq_len, d_head) --> (b_size, seq_len, d_feature)
      x = x.reshape((-1, n_heads, seq_len, d_head))
      x = x.transpose((0, 2, 1, 3))
      x = x.reshape((-1, seq_len, d_head * n_heads))
      return x
    return tl.Fn('MergeHeads', f)

  @assert_shape('...a->...b')
  def ProcessingLayer():  # pylint: disable=invalid-name
    if n_modules == 1:
      return tl.Dense(d_feature)
    else:
      assert d_feature % n_modules == 0
      return LocallyConnectedDense(n_modules, d_feature // n_modules)

  return cb.Serial(
      cb.Branch(
          [ProcessingLayer(), _split_into_heads()],
          [ProcessingLayer(), _split_into_heads()],
          [ProcessingLayer(), _split_into_heads()],
      ),
      tl.DotProductCausalAttention(
          dropout=dropout, max_inference_length=max_inference_length,
          mode=mode),
      _merge_heads(),
      ProcessingLayer())
def CreateAttentionMaskLayer(): """Creates attention mask layer. Returns a layer that based on queries, keys and accumulated pool size of keys/values until this layer calculates positional embeddings for causal relative attention calculations. Takes as input q, k, v and appends proper mask in the end. Causal attention uses masking to prevent a given sequence position from attending to positions greater than / following it. This is used, for example, when training autoregressive sequence models, or when decoding a sequence symbol by symbol. Returns: an attention mask layer. """ def calculate_mask(queries, keys): batch_size = queries.shape[0] keys_len, queries_len = keys.shape[-2], queries.shape[-2] funnel_factor, is_upsampling = calc_funnel_ratio(keys_len, queries_len) return _funnel_mask(batch_size, keys_len, queries_len, funnel_factor, is_upsampling) def _funnel_mask(batch_size, keys_len, queries_len, funnel_factor, is_upsampling): """Funnel mask. Args: batch_size: batch size. keys_len: keys length. queries_len: queries length. funnel_factor: funnel factor. is_upsampling: True or False. Returns: funnel mask. This function based on keys/queries lengths creates a triangle mask that prevents tokens from attending to positions following it. If funnel_factor is not equal to 1 due to funnel upsampling or downsampling it adjusts created mask for funnel attention by repeating each element funnel_factor times. This is because after funnel layer one token attends to funnel_factor different tokens in downsampling. During upsampling on the other hand funnel_factor tokens are attending to single token before upsampling. """ if funnel_factor != 1: if not is_upsampling: mask = jnp.tril( jnp.ones((queries_len, queries_len), dtype=jnp.bool_)) mask = jnp.repeat(mask, funnel_factor, axis=-1) else: mask = jnp.tril(jnp.ones((keys_len, keys_len), dtype=jnp.bool_)) mask = jnp.repeat(mask, funnel_factor, axis=-2) else: mask = jnp.tril( jnp.ones((queries_len, queries_len), dtype=jnp.bool_)) return jnp.repeat(mask[None, None, :, :], batch_size, axis=0) return cb.Branch( cb.Select([0]), cb.Select([1]), cb.Select([2]), cb.Fn('create attention mask layer', calculate_mask, n_out=1))
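# A small NumPy illustration of the downsampling branch of _funnel_mask above
# (values are illustrative): with queries_len=2, keys_len=4, funnel_factor=2,
# each pooled query position may attend to funnel_factor raw key positions.
import numpy as np

queries_len, funnel_factor = 2, 2
mask = np.tril(np.ones((queries_len, queries_len), dtype=bool))
mask = np.repeat(mask, funnel_factor, axis=-1)
# mask.astype(int) ==
# [[1 1 0 0]
#  [1 1 1 1]]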