def test_select_computes_n_in(self):
  """Checks the `n_in` that `Select` derives when none is given explicitly."""
  cases = (
      ([0, 0], 1),
      ([1, 0], 2),
      ([2], 3),
  )
  for indices, expected_n_in in cases:
    layer = cb.Select(indices)
    self.assertEqual(layer.n_in, expected_n_in)
def RelativeAttentionLayer(d_feature,
                           context_bias_layer,
                           location_bias_layer,
                           total_kv_pooling,
                           separate_cls,
                           n_heads=1,
                           dropout=0.0,
                           mode='train'):
  """Returns a layer that maps (q, k, v, masks) to (activations, masks).

  When the number of keys is smaller than the number of queries, the layer
  works in O(q^2*d); otherwise it is O(q*k*d). That is because relative
  distances must be shifted by the current pooling, and when upsampling the
  current pooling is a fraction < 1.

  Visual explanation:
  [01][23][45][67] -> [0][1][2][3][4][5][6][7]
  For token [0] we calculate relative distances as follows:
  * 0 2 4 6
  However for token [1] we need relative distances changed by 1, specifically:
  * -1 1 3 5
  So we need the distances that correspond to the spacing between the keys
  and also the ones in between, because a single key token is attended to by
  several query tokens sitting at different positions (hence with different
  relative distances).

  Args:
    d_feature: Depth/dimensionality of feature embedding.
    context_bias_layer: Global context bias from Transformer XL's attention.
      There should be one such layer shared for all relative attention layers.
    location_bias_layer: Global location bias from Transformer XL's attention.
      There should be one such layer shared for all relative attention layers.
    total_kv_pooling: Accumulated pool size of keys/values used at this layer.
    separate_cls: True/False if we separate_cls in calculations.
    n_heads: Number of attention heads.
    dropout: Probabilistic rate for internal dropout applied to attention
      activations (based on query-key pairs) before dotting them with values.
    mode: One of `'train'`, `'eval'`, or `'predict'`.
  """
  # Put positional embeddings on top of the stack, followed by copies of the
  # first two inputs.
  embeddings_and_inputs = cb.Branch(
      PositionalEmbeddings(d_feature, separate_cls, total_kv_pooling),
      cb.Select([0]),
      cb.Select([1]))
  # Four independent Dense projections, one per leading stack item.
  projections = cb.Parallel(*[core.Dense(d_feature) for _ in range(4)])
  attention = RelativeAttention(  # pylint: disable=no-value-for-parameter
      separate_cls=separate_cls,
      n_heads=n_heads,
      dropout=dropout,
      mode=mode)
  output_projection = core.Dense(d_feature)

  return cb.Serial(
      embeddings_and_inputs,
      projections,
      context_bias_layer,
      location_bias_layer,
      attention,
      output_projection,
  )
def SRU(n_units, activation=None):
  """SRU (Simple Recurrent Unit) layer as in https://arxiv.org/abs/1709.02755.

  As defined in the paper:
  (1) y_t = W x_t (+ B optionally, which we do)
  (2) f_t = sigmoid(Wf x_t + bf)
  (3) r_t = sigmoid(Wr x_t + br)
  (4) c_t = f_t * c_{t-1} + (1 - f_t) * y_t
  (5) h_t = r_t * activation(c_t) + (1 - r_t) * x_t

  The input is assumed to have shape [batch, length, depth]; recurrence runs
  along the length dimension. A single layer is returned. The paper suggests
  stacking at least 2, except inside a Transformer.

  Args:
    n_units: output depth of the SRU layer.
    activation: Optional activation function.

  Returns:
    The SRU layer.
  """
  # One sigmoid instance is deliberately reused for both gates.
  gate = activation_fns.Sigmoid()  # pylint: disable=no-value-for-parameter
  stages = [
      # Stack on entry: x
      cb.Branch(core.Dense(3 * n_units), []),          # r_f_y, x
      cb.Split(n_items=3),                             # r, f, y, x
      cb.Parallel(gate, gate),                         # r, f, y, x
      base.Fn(lambda r, f, y: (y * (1.0 - f), f, r)),  # y * (1 - f), f, r, x
      cb.Parallel([], [], cb.Branch(MakeZeroState(), [])),
      cb.Scan(InnerSRUCell(), axis=1),
      cb.Select([0], n_in=2),                          # act(c), r, x
      activation if activation else [],
      base.Fn(lambda c, r, x: c * r + x * (1 - r)),
  ]
  return cb.Serial(*stages)
def Attention(d_feature, n_heads=1, dropout=0.0, mode='train'):
  """Returns a layer that maps (activations, mask) to (new_activations, mask).

  This layer type represents one pass of multi-head self-attention, best
  known for its central role in Transformer models. Internally, it:

    - maps activations to `(queries, keys, values)` triples,
    - splits `queries`, `keys`, and `values` into multiple 'heads',
    - computes per-head attention weights from per-head `(queries, keys)`,
    - applies `mask` to screen out positions that come from padding tokens,
    - optionally applies dropout to attention weights,
    - uses attention weights to combine per-head `values` vectors, and
    - fuses per-head results into activations matching original input shapes.

  Args:
    d_feature: Depth/dimensionality of feature embedding.
    n_heads: Number of attention heads.
    dropout: Probabilistic rate for internal dropout applied to attention
      activations (based on query-key pairs) before dotting them with values.
    mode: Either 'train' or 'eval'.
  """
  # Self-attention: the same activations are fed in three times.
  triplicate_input = cb.Select([0, 0, 0])
  qkv_attention = AttentionQKV(
      d_feature, n_heads=n_heads, dropout=dropout, mode=mode)
  return cb.Serial(triplicate_input, qkv_attention)
def test_select_second_of_3(self):
  """Selecting index 1 out of three inputs reports the second input's shape."""
  shapes = ((3, 2), (4, 7), (11, 13))
  layer = cb.Select([1], n_in=3)
  input_signature = tuple(ShapeDtype(shape) for shape in shapes)
  output_shape = base.check_shape_agreement(layer, input_signature)
  self.assertEqual(output_shape, shapes[1])
def Attention(d_feature, n_heads=1, dropout=0.0, mode='train'):
  """Returns a layer that maps (activations, mask) to (new_activations, mask).

  This layer type represents one pass of multi-head self-attention, best
  known for its central role in Transformer models. Internally, it:

    - maps incoming sequence of activations to sequence of (query, key, value)
      triples,
    - splits queries, keys, and values into multiple 'heads',
    - computes per-head attention weights from per-head (queries, keys),
    - applies mask to screen out positions that come from padding tokens,
    - [in ``'train'`` mode] applies dropout to attention weights,
    - uses attention weights to combine per-head values vectors, and
    - fuses per-head results into outgoing activations matching original
      input activation shapes.

  Args:
    d_feature: Depth/dimensionality of feature embedding.
    n_heads: Number of attention heads.
    dropout: Probabilistic rate for internal dropout applied to attention
      activations (based on query-key pairs) before dotting them with values.
    mode: One of ``'train'``, ``'eval'``, or ``'predict'``.
  """
  sublayers = [
      # The single activations input is copied to act as q, k and v.
      cb.Select([0, 0, 0]),
      AttentionQKV(d_feature, n_heads=n_heads, dropout=dropout, mode=mode),
  ]
  return cb.Serial(*sublayers)
def Attention(d_feature, n_heads=1, dropout=0.0, mode='train'):
  """Returns a layer that maps (vectors, mask) to (new_vectors, mask).

  This layer type represents one pass of multi-head self-attention, from
  vector set to vector set, using masks to represent out-of-bound (e.g.,
  padding) positions. It:

    - maps incoming sequence of activation vectors to sequence of
      (query, key, value) triples,
    - splits queries, keys, and values into multiple 'heads',
    - computes per-head attention weights from per-head (queries, keys),
    - applies mask to screen out positions that come from padding tokens,
    - [in ``'train'`` mode] applies dropout to attention weights,
    - uses attention weights to combine per-head values vectors, and
    - fuses per-head results into outgoing activations matching original
      input activation shapes.

  Args:
    d_feature: Depth/dimensionality of feature embedding.
    n_heads: Number of attention heads.
    dropout: Probabilistic rate for attention dropout, which overrides
      (sets to zero) some attention strengths derived from query-key
      matching. As a result, on a given forward pass, some value vectors
      don't contribute to the output, analogous to how regular dropout can
      cause some node activations to be ignored.
    mode: One of ``'train'``, ``'eval'``, or ``'predict'``.
  """
  attention_kwargs = dict(n_heads=n_heads, dropout=dropout, mode=mode)
  # Copy the vectors three times so they serve as queries, keys and values.
  copy_as_qkv = cb.Select([0, 0, 0])
  return cb.Serial(copy_as_qkv, AttentionQKV(d_feature, **attention_kwargs))
def RelativeAttentionLMLayer(d_feature,
                             total_kv_pooling,
                             n_heads=1,
                             dropout=0.0,
                             n_raw_tokens_generated=1,
                             max_inference_length=3072,
                             chunk_len=None,
                             chunk_offset=None,
                             mode='train'):
  """Returns a layer that maps (q, k, v) to (activations).

  Same as the standard relative attention layer, but additionally prepares,
  from the sizes of queries and keys, a mask that masks out the future.
  Masking the future is the concept primarily used for language modelling.

  Args:
    d_feature: Depth/dimensionality of feature embedding.
    total_kv_pooling: Accumulated pool size of keys/values used at this layer.
    n_heads: Number of attention heads.
    dropout: Probabilistic rate for internal dropout applied to attention
      activations (based on query-key pairs) before dotting them with values.
    n_raw_tokens_generated: Number of tokens generated in a single pass
      through this layer. Used only in 'predict' non-training mode.
    max_inference_length: Maximum sequence length allowed in non-training
      modes.
    chunk_len (optional): Number of tokens per chunk. Setting this option
      will enable chunked attention.
    chunk_offset (optional): Offset for shifting chunks, for shifted chunked
      attention.
    mode: One of `'train'`, `'eval'`, or `'predict'`.
  """
  # Settings that the attention layer and the mask layer both receive.
  shared_kwargs = dict(
      n_raw_tokens_generated=n_raw_tokens_generated,
      max_inference_length=max_inference_length,
      chunk_len=chunk_len,
      chunk_offset=chunk_offset,
      mode=mode)

  attention = RelativeAttentionLayer(
      d_feature,
      total_kv_pooling,
      n_heads=n_heads,
      dropout=dropout,
      **shared_kwargs)
  mask_layer = AttentionMaskLayer(
      total_kv_pooling=total_kv_pooling, **shared_kwargs)

  add_mask = cb.Branch(None, mask_layer)  # vecs, mask
  drop_mask = cb.Select([0], n_in=2)      # vecs
  return cb.Serial(
      add_mask,
      attention,  # vecs, mask
      drop_mask,
  )
def LSTM(n_units):
  """LSTM running on axis 1.

  Args:
    n_units: `n_units` passed to the underlying `LSTMCell`.

  Returns:
    A serial layer that scans an `LSTMCell` over axis 1, starting from a zero
    state, and outputs only the activations (the final state is dropped).
  """
  initial_state = MakeZeroState(depth_multiplier=2)  # pylint: disable=no-value-for-parameter
  stages = [
      cb.Branch([], initial_state),  # Pair the input with a zero state.
      cb.Scan(LSTMCell(n_units=n_units), axis=1),
      cb.Select([0], n_in=2),  # Drop RNN state.
  ]
  return cb.Serial(*stages)
def LSTM(n_units, mode='train', return_state=False, initial_state=False):
  """LSTM running on axis 1.

  Args:
    n_units: `n_units` for the `LSTMCell`.
    mode: if 'predict' then we save the previous state for one-by-one
      inference.
    return_state: Boolean. Whether to return the latest state in addition to
      the output. Default: False.
    initial_state: Boolean. If True, the RNN state (c, h) is taken from the
      stack; otherwise a zero state is created. Default: False.

  Returns:
    An LSTM layer.
  """
  stages = []
  if not initial_state:
    zero_state = MakeZeroState(depth_multiplier=2)  # pylint: disable=no-value-for-parameter
    stages.append(cb.Branch([], zero_state))  # Fill the RNN state with zeros.

  stages.append(cb.Scan(LSTMCell(n_units=n_units), axis=1, mode=mode))

  if not return_state:
    stages.append(cb.Select([0], n_in=2))  # Drop RNN state.

  # Set the name to LSTM and don't print sublayers.
  return cb.Serial(*stages, name=f'LSTM_{n_units}', sublayers_to_print=[])
def AttentionResampling(shorten_factor, d_model, is_upsampling, d_ff, n_heads,
                        dropout, dropout_shared_axes, mode, ff_activation,
                        context_bias_layer, location_bias_layer, total_pooling,
                        resampling_fn):
  """Attention resampling.

  Builds a list of layers: layer norm, a resampling branch, an optional merge
  step used only when upsampling, a residual relative-attention block and a
  residual feed-forward block.

  Args:
    shorten_factor: First argument forwarded to `resampling_fn`.
    d_model: Model depth, forwarded to the attention, feed-forward and
      resampling layers.
    is_upsampling: If True, the merge step (`Select([0, 2, 1, 2])` followed by
      `Add`) is inserted after the resampling branch.
    d_ff: Forwarded to `FeedForwardBlock`.
    n_heads: Number of attention heads.
    dropout: Dropout rate used by attention, feed-forward and the residual
      dropout layers.
    dropout_shared_axes: `shared_axes` for the dropout layers.
    mode: Forwarded to every sublayer that takes a mode.
    ff_activation: Forwarded to `FeedForwardBlock`.
    context_bias_layer: Forwarded to `RelativeAttentionLMLayer`.
    location_bias_layer: Forwarded to `RelativeAttentionLMLayer`.
    total_pooling: Forwarded to `RelativeAttentionLMLayer`.
    resampling_fn: Callable invoked as
      `resampling_fn(shorten_factor, d_model, mode=mode)` to build the
      resampling layer.

  Returns:
    A list of layers (one entry is an empty list when not upsampling).
  """

  def _Dropout():
    return core.Dropout(
        rate=dropout, shared_axes=dropout_shared_axes, mode=mode)

  attention = RelativeAttentionLMLayer(
      d_model,
      context_bias_layer,
      location_bias_layer,
      total_pooling,
      n_heads=n_heads,
      dropout=dropout,
      mode=mode)
  feed_forward = FeedForwardBlock(
      d_model, d_ff, dropout, dropout_shared_axes, mode, ff_activation)
  resampling = resampling_fn(shorten_factor, d_model, mode=mode)

  input_norm = LayerNorm()  # h
  resample_branch = cb.Branch(
      cb.Serial(resampling, LayerNorm()),
      None)  # h', h

  if is_upsampling:
    merge = cb.Serial(cb.Select([0, 2, 1, 2]), cb.Add())
  else:
    merge = []

  attention_block = cb.Residual(
      cb.Select([0, 1, 1]),  # h', h, h
      attention,
      _Dropout())
  feed_forward_block = cb.Residual(LayerNorm(), feed_forward, _Dropout())

  return [
      input_norm,
      resample_branch,
      merge,
      attention_block,
      feed_forward_block,
  ]
def GRU(n_units):
  """GRU running on axis 1.

  Args:
    n_units: `n_units` passed to the underlying `GRUCell`; also used in the
      layer's name.

  Returns:
    A serial layer named `GRU_<n_units>` that scans a `GRUCell` over axis 1
    from a zero state and outputs only the activations.
  """
  initial_state = MakeZeroState(depth_multiplier=1)  # pylint: disable=no-value-for-parameter
  stages = [
      cb.Branch([], initial_state),  # Pair the input with a zero state.
      cb.Scan(GRUCell(n_units=n_units), axis=1),
      cb.Select([0], n_in=2),  # Drop RNN state.
  ]
  # Set the name to GRU and don't print sublayers.
  return cb.Serial(*stages, name=f'GRU_{n_units}', sublayers_to_print=[])
def LSTM(n_units, mode='train'):
  """LSTM running on axis 1.

  Args:
    n_units: `n_units` passed to the underlying `LSTMCell`; also used in the
      layer's name.
    mode: Forwarded to the `Scan` combinator.

  Returns:
    A serial layer named `LSTM_<n_units>` that scans an `LSTMCell` over
    axis 1 from a zero state and outputs only the activations.
  """
  initial_state = MakeZeroState(depth_multiplier=2)  # pylint: disable=no-value-for-parameter
  with_state = cb.Branch([], initial_state)
  recurrence = cb.Scan(LSTMCell(n_units=n_units), axis=1, mode=mode)
  drop_state = cb.Select([0], n_in=2)  # Drop RNN state.
  # Set the name to LSTM and don't print sublayers.
  return cb.Serial(
      with_state,
      recurrence,
      drop_state,
      name=f'LSTM_{n_units}',
      sublayers_to_print=[],
  )
def RelativeAttentionLMLayer(d_feature,
                             context_bias_layer,
                             location_bias_layer,
                             total_kv_pooling,
                             separate_cls=False,
                             n_heads=1,
                             dropout=0.0,
                             n_raw_tokens_generated=1,
                             max_inference_length=3072,
                             mode='train'):
  """Returns a layer that maps (q, k, v) to (activations).

  Same as the standard relative attention layer, but additionally prepares,
  from the sizes of queries and keys, a mask that masks out the future.
  Masking the future is the concept primarily used for language modelling.

  Args:
    d_feature: Depth/dimensionality of feature embedding.
    context_bias_layer: Global context bias from Transformer XL's attention.
      There should be one such layer shared for all relative attention layers.
    location_bias_layer: Global location bias from Transformer XL's attention.
      There should be one such layer shared for all relative attention layers.
    total_kv_pooling: Accumulated pool size of keys/values used at this layer.
    separate_cls: True/False if we separate_cls in calculations.
    n_heads: Number of attention heads.
    dropout: Probabilistic rate for internal dropout applied to attention
      activations (based on query-key pairs) before dotting them with values.
    n_raw_tokens_generated: Number of tokens generated in a single pass
      through this layer. Used only in 'predict' non-training mode.
    max_inference_length: Maximum sequence length allowed in non-training
      modes.
    mode: One of `'train'`, `'eval'`, or `'predict'`.
  """
  # Inference-related settings shared by the attention and mask layers.
  inference_kwargs = dict(
      n_raw_tokens_generated=n_raw_tokens_generated,
      max_inference_length=max_inference_length,
      mode=mode)

  attention = RelativeAttentionLayer(
      d_feature,
      context_bias_layer,
      location_bias_layer,
      total_kv_pooling,
      separate_cls,
      n_heads=n_heads,
      dropout=dropout,
      **inference_kwargs)
  mask_layer = AttentionMaskLayer(
      total_kv_pooling=total_kv_pooling, **inference_kwargs)

  stages = [
      mask_layer,              # q, k, v, mask
      attention,               # vecs, mask
      cb.Select([0], n_in=2),  # vecs
  ]
  return cb.Serial(*stages)
def _WeightedMaskedMean(metric_layer, id_to_mask, has_weights):
  """Computes weighted masked mean of metric_layer(predictions, targets).

  Args:
    metric_layer: Layer applied to (predictions, targets).
    id_to_mask: Forwarded to `_ElementMask`, which is applied to a copy of
      the targets.
    has_weights: If True, the element mask is multiplied with the next stack
      item (the weights); otherwise the mask alone is used.

  Returns:
    A layer with 2 or 3 inputs -- predictions, targets, (weights) -- that
    applies the metric to a batch and gathers the results into one value.
  """
  if has_weights:
    multiply_by_weights = cb.Multiply()
  else:
    multiply_by_weights = []

  duplicate_targets = cb.Select([0, 1, 1])
  metric_and_mask = cb.Parallel(
      metric_layer, _ElementMask(id_to_mask=id_to_mask))
  apply_weights = cb.Parallel([], multiply_by_weights)
  # Stack after `apply_weights`: metric_values weights
  return cb.Serial(
      duplicate_targets,
      metric_and_mask,
      apply_weights,
      _WeightedMean(),
  )
def recombine(eqns, inputs, outputs):
  """Implement derived equations via layer-applications and combinators.

  Args:
    eqns: list of ApplyEqns derived from dataflow traces. Each is read for
      its `src`, `dst` (symbol tuples) and `lyr` (the layer to apply).
    inputs: list of strings representing input symbols
    outputs: list of strings representing output symbols

  Returns:
    Trax layer object that implements the given dataflow on provided layers.
  """
  # Walk the equations backwards to learn, for each one, which symbols are
  # still needed by anything that runs after it. Those must survive on the
  # data stack; everything else can be discarded.
  needed = set(outputs)
  needed_after = []
  for eqn in reversed(eqns):
    needed_after.append(needed)
    needed = needed.union(eqn.src)
  needed_after.reverse()

  stack = tuple(inputs)  # models the data stack
  layers = []            # output trax layers

  # Walk forwards, rearranging the stack so each layer finds its inputs on
  # top, followed by whatever must be kept for later.
  for eqn, keep in zip(eqns, needed_after):
    carried = tuple(sym for sym in stack if sym in keep)
    wanted = eqn.src + carried
    # Only insert a data-routing layer if the stack is not already right.
    if wanted != stack:
      indices = [stack.index(sym) for sym in wanted]
      layers.append(cb.Select(indices, len(stack)))
    layers.append(eqn.lyr)
    stack = eqn.dst + carried

  # Finally, if needed, select out the final outputs from the data stack.
  if tuple(outputs) != stack:
    final_indices = [stack.index(sym) for sym in outputs]
    layers.append(cb.Select(final_indices, len(stack)))

  return cb.Serial(*layers)
def SRU(n_units, activation=None, mode='train'):
  r"""SRU (Simple Recurrent Unit) layer as in https://arxiv.org/abs/1709.02755.

  As defined in the paper:

  .. math::
      y_t &= W x_t + B \quad \hbox{(include $B$ optionally)} \\
      f_t &= \sigma(Wf x_t + bf) \\
      r_t &= \sigma(Wr x_t + br) \\
      c_t &= f_t \times c_{t-1} + (1 - f_t) \times y_t \\
      h_t &= r_t \times \hbox{activation}(c_t) + (1 - r_t) \times x_t

  The input is assumed to have shape [batch, length, depth]; recurrence runs
  along the length dimension. A single layer is returned. The paper suggests
  stacking at least 2, except inside a Transformer.

  Args:
    n_units: output depth of the SRU layer.
    activation: Optional activation function.
    mode: if 'predict' then we save the previous state for one-by-one
      inference.

  Returns:
    The SRU layer.
  """
  # One sigmoid instance is deliberately reused for both gates.
  gate = activation_fns.Sigmoid()
  stages = [
      # Stack on entry: x
      cb.Branch(core.Dense(3 * n_units), []),  # r_f_y, x
      cb.Split(n_items=3),                     # r, f, y, x
      cb.Parallel(gate, gate),                 # r, f, y, x
      base.Fn(
          '',
          lambda r, f, y: (y * (1.0 - f), f, r),  # y * (1 - f), f, r, x
          n_out=3),
      cb.Parallel([], [], cb.Branch(MakeZeroState(), [])),
      ScanSRUCell(mode=mode),
      cb.Select([0], n_in=2),                  # act(c), r, x
      [] if activation is None else activation,
      # The highway term is additionally scaled by sqrt(3) here.
      base.Fn('FinalSRUGate',
              lambda c, r, x: c * r + x * (1 - r) * (3**0.5)),
  ]
  # Set the name to SRU and don't print sublayers.
  return cb.Serial(*stages, name=f'SRU_{n_units}', sublayers_to_print=[])
def RelativeAttentionWrapper(d_feature,
                             n_heads=1,
                             dropout=0.0,
                             max_inference_length=2048,
                             mode='train',
                             context_bias_layer=None,
                             location_bias_layer=None,
                             total_pooling=None):
  """Relative attention wrapper.

  Relative attention wrapper for compatibility with configurable attention,
  so that it can be called by `ApplyAttentionLayer`.

  Args:
    d_feature: Last/innermost dimension of activations in the input to and
      output from this layer.
    n_heads: Number of attention heads. Attention heads effectively split
      activation vectors into ``n_heads`` subvectors, of size
      ``d_feature / n_heads``.
    dropout: dropout rate.
    max_inference_length: max inference length. Accepted for signature
      compatibility only; it is discarded and never forwarded.
    mode: One of ``'train'``, ``'eval'``, or ``'predict'``.
    context_bias_layer: context bias layer.
    location_bias_layer: location bias layer.
    total_pooling: total pooling.

  Returns:
    relative attention layer.
  """
  del max_inference_length  # Unused.

  lm_attention = RelativeAttentionLMLayer(
      d_feature,
      context_bias_layer,
      location_bias_layer,
      total_pooling,
      n_heads=n_heads,
      dropout=dropout,
      mode=mode)
  # Feed the single input three times into the attention layer.
  triplicate_input = cb.Select([0, 0, 0])
  return cb.Serial(triplicate_input, lm_attention)
def Attention(d_feature, n_heads=1, dropout=0.0, mode='train'):
  """Returns a layer that maps `(vectors, mask)` to `(new_vectors, mask)`.

  This layer type represents one pass of multi-head self-attention, from
  vector set to vector set, using masks to represent out-of-bound (e.g.,
  padding) positions. It:

    - makes three copies of incoming activations and maps these to multi-head
      query (Q) vectors, key (K) vectors, and value (V) vectors,
      respectively;
    - for each head, computes the scaled dot product of each Q-K pair;
    - applies mask to screen out positions that come from padding tokens
      (indicated by 0 value);
    - [in ``'train'`` mode] applies dropout to Q-K dot products;
    - for each head, computes Q-K attention strengths using a per-query
      softmax of the Q-K dot products;
    - for each head, for each query position, combines V vectors according
      to the Q-K attention strengths; and
    - concatenates and fuses resulting per-head vectors into outgoing
      activations matching original input activation shapes.

  Args:
    d_feature: Last/innermost dimension of activations in the input to and
      output from this layer.
    n_heads: Number of attention heads. Attention heads effectively split
      activation vectors into ``n_heads`` subvectors, of size
      ``d_feature / n_heads``.
    dropout: Probabilistic rate for attention dropout, which overrides
      (sets to zero) some attention strengths derived from query-key
      matching. As a result, on a given forward pass, some value vectors
      don't contribute to the output, analogous to how regular dropout can
      cause some node activations to be ignored. Applies only if layer is
      created in ``'train'`` mode.
    mode: One of ``'train'``, ``'eval'``, or ``'predict'``.
  """
  # Three copies of the incoming activations: one each for Q, K and V.
  make_qkv_copies = cb.Select([0, 0, 0])
  attend = AttentionQKV(
      d_feature,
      n_heads=n_heads,
      dropout=dropout,
      mode=mode,
  )
  return cb.Serial(make_qkv_copies, attend)
def SRU(n_units, activation=None):
  r"""SRU (Simple Recurrent Unit) layer as in https://arxiv.org/abs/1709.02755.

  As defined in the paper:

  .. math::
      y_t &= W x_t + B \quad \hbox{(include $B$ optionally)} \\
      f_t &= \sigma(Wf x_t + bf) \\
      r_t &= \sigma(Wr x_t + br) \\
      c_t &= f_t \times c_{t-1} + (1 - f_t) \times y_t \\
      h_t &= r_t \times \hbox{activation}(c_t) + (1 - r_t) \times x_t

  The input is assumed to have shape [batch, length, depth]; recurrence runs
  along the length dimension. A single layer is returned. The paper suggests
  stacking at least 2, except inside a Transformer.

  Args:
    n_units: output depth of the SRU layer.
    activation: Optional activation function.

  Returns:
    The SRU layer.
  """
  # One sigmoid instance is deliberately reused for both gates.
  gate = activation_fns.Sigmoid()

  project_and_keep_input = cb.Branch(core.Dense(3 * n_units), [])  # r_f_y, x
  split_gates = cb.Split(n_items=3)                                # r, f, y, x
  squash_gates = cb.Parallel(gate, gate)                           # r, f, y, x
  prepare_cell_input = base.Fn(
      '',
      lambda r, f, y: (y * (1.0 - f), f, r),  # y * (1 - f), f, r, x
      n_out=3)
  add_zero_state = cb.Parallel([], [], cb.Branch(MakeZeroState(), []))
  recurrence = cb.Scan(InnerSRUCell(), axis=1)
  drop_state = cb.Select([0], n_in=2)                              # act(c), r, x
  optional_activation = activation if activation else []
  # The highway term is additionally scaled by sqrt(3) here.
  final_gate = base.Fn(
      'FinalSRUGate', lambda c, r, x: c * r + x * (1 - r) * (3**0.5))

  return cb.Serial(
      project_and_keep_input,
      split_gates,
      squash_gates,
      prepare_cell_input,
      add_zero_state,
      recurrence,
      drop_state,
      optional_activation,
      final_gate,
  )
def SRU(n_units, activation=None, rescale=False, highway_bias=0):
  """SRU (Simple Recurrent Unit) layer as in https://arxiv.org/abs/1709.02755.

  As defined in the paper:
  (1) y_t = W x_t (+ B optionally, which we do)
  (2) f_t = sigmoid(Wf x_t + bf)
  (3) r_t = sigmoid(Wr x_t + br)
  (4) c_t = f_t * c_{t-1} + (1 - f_t) * y_t
  (5) h_t = r_t * activation(c_t) + (1 - r_t) * x_t * alpha

  The input is assumed to have shape [batch, length, depth]; recurrence runs
  along the length dimension. A single layer is returned. The paper suggests
  stacking at least 2, except inside a Transformer.

  Args:
    n_units: output depth of the SRU layer.
    activation: Optional activation function.
    rescale: To offset the problem of the gradient vanishing in h_t as a
      result of light recurrence and highway computation for deeper layers,
      a scaling correction alpha is applied as follows:
      (1 + exp(highway_bias) * 2)**0.5
      ref: https://arxiv.org/abs/1709.02755, page 4, section 3.2
      Initialization. When False, alpha is 1.
    highway_bias: initial bias of highway gates. In this function it is only
      used to compute alpha.

  Returns:
    The SRU layer.
  """
  # pylint: disable=no-value-for-parameter
  stages = [
      # Stack on entry: x
      cb.Branch(core.Dense(3 * n_units), []),          # r_f_y, x
      cb.Split(n_items=3),                             # r, f, y, x
      cb.Parallel(core.Sigmoid(), core.Sigmoid()),     # r, f, y, x
      base.Fn(lambda r, f, y: (y * (1.0 - f), f, r)),  # y * (1 - f), f, r, x
      cb.Parallel([], [], cb.Branch(MakeZeroState(), [])),
      cb.Scan(InnerSRUCell(), axis=1),
      cb.Select([0], n_in=2),                          # act(c), r, x
      activation if activation else [],
      # Final highway gate; the x term is scaled by alpha when `rescale`.
      base.Fn(lambda c, r, x: c * r + x * (1 - r) * (
          (1 + np.exp(highway_bias) * 2)**0.5 if rescale else 1)),
  ]
  return cb.Serial(*stages)
def CreateAttentionMaskLayer():
  """Creates attention mask layer.

  Returns a layer that, based on the lengths of queries and keys, builds a
  causal mask for relative attention. It takes q, k, v as input and appends
  the mask at the end, so the output is (q, k, v, mask).

  Causal attention uses masking to prevent a given sequence position from
  attending to positions greater than / following it. This is used, for
  example, when training autoregressive sequence models, or when decoding a
  sequence symbol by symbol.

  Returns:
    an attention mask layer.
  """

  def _funnel_mask(batch_size, keys_len, queries_len, funnel_factor,
                   is_upsampling):
    """Funnel mask.

    Based on keys/queries lengths, creates a triangular mask that prevents
    tokens from attending to positions following them. If funnel_factor is
    not equal to 1 (due to funnel upsampling or downsampling), the mask is
    adjusted for funnel attention by repeating each element funnel_factor
    times. This is because after a funnel layer one token attends to
    funnel_factor different tokens in downsampling. During upsampling, on the
    other hand, funnel_factor tokens attend to a single token before
    upsampling.

    Args:
      batch_size: batch size.
      keys_len: keys length.
      queries_len: queries length.
      funnel_factor: funnel factor.
      is_upsampling: True or False.

    Returns:
      funnel mask: a boolean array with leading dimensions (batch_size, 1).
    """
    # Pick the side of the square triangular mask and the axis (if any) along
    # which it is stretched by `funnel_factor`.
    if funnel_factor == 1:
      side, stretch_axis = queries_len, None
    elif is_upsampling:
      side, stretch_axis = keys_len, -2
    else:
      side, stretch_axis = queries_len, -1

    mask = jnp.tril(jnp.ones((side, side), dtype=jnp.bool_))
    if stretch_axis is not None:
      mask = jnp.repeat(mask, funnel_factor, axis=stretch_axis)

    return jnp.repeat(mask[None, None, :, :], batch_size, axis=0)

  def calculate_mask(queries, keys):
    """Builds the mask from the shapes of `queries` and `keys`."""
    queries_len = queries.shape[-2]
    keys_len = keys.shape[-2]
    funnel_factor, is_upsampling = calc_funnel_ratio(keys_len, queries_len)
    return _funnel_mask(queries.shape[0], keys_len, queries_len,
                        funnel_factor, is_upsampling)

  pass_through = [cb.Select([index]) for index in range(3)]  # q, k, v
  mask_fn = cb.Fn('create attention mask layer', calculate_mask, n_out=1)
  return cb.Branch(*pass_through, mask_fn)
def test_select_op_not_defined(self):
  """Constructing `Select` with these arguments must raise AttributeError."""
  input_shape = ((3, 2), (4, 7))
  self.assertRaises(AttributeError, cb.Select, 1, input_shape)
def test_select_given_n_in(self):
  """An explicitly supplied `n_in` is reported back by the layer."""
  for given_n_in in (2, 3):
    layer = cb.Select([0], n_in=given_n_in)
    self.assertEqual(layer.n_in, given_n_in)
def RelativeAttentionLayer(d_feature,
                           total_kv_pooling,
                           n_heads=1,
                           dropout=0.0,
                           n_raw_tokens_generated=1,
                           max_inference_length=3072,
                           chunk_len=None,
                           chunk_offset=None,
                           mode='train'):
  """Returns a layer that maps (q, k, v, masks) to (activations, masks).

  When the number of keys is smaller than the number of queries, the layer
  works in O(q^2*d); otherwise it is O(q*k*d). That is because relative
  distances must be shifted by the current pooling, and when upsampling the
  current pooling is a fraction < 1.

  Visual explanation:
  [01][23][45][67] -> [0][1][2][3][4][5][6][7]
  For token [0] we calculate relative distances as follows:
  * 0 2 4 6
  However for token [1] we need relative distances changed by 1, specifically:
  * -1 1 3 5
  So we need the distances that correspond to the spacing between the keys
  and also the ones in between, because a single key token is attended to by
  several query tokens sitting at different positions (hence with different
  relative distances).

  Args:
    d_feature: Depth/dimensionality of feature embedding. Must be divisible
      by `n_heads`.
    total_kv_pooling: Accumulated pool size of keys/values used at this layer.
    n_heads: Number of attention heads.
    dropout: Probabilistic rate for internal dropout applied to attention
      activations (based on query-key pairs) before dotting them with values.
    n_raw_tokens_generated: Number of tokens generated in a single pass
      through this layer. Used only in 'predict' non-training mode.
    max_inference_length: Maximum sequence length allowed in non-training
      modes.
    chunk_len (optional): Number of tokens per chunk. Setting this option
      will enable chunked attention.
    chunk_offset (optional): Offset for shifting chunks, for shifted chunked
      attention.
    mode: One of `'train'`, `'eval'`, or `'predict'`.
  """
  pos_emb = PositionalEmbeddings(
      d_feature,
      total_kv_pooling,
      max_inference_length=max_inference_length,
      chunk_len=chunk_len,
      chunk_offset=chunk_offset,
      n_raw_tokens_generated=n_raw_tokens_generated,
      mode=mode)

  # NOTE: no trailing comma after the call -- `attention` must be the layer
  # itself, not a 1-tuple wrapping it.
  attention = RelativeAttention(  # pylint: disable=no-value-for-parameter
      total_kv_pooling=total_kv_pooling,
      n_heads=n_heads,
      dropout=dropout,
      n_raw_tokens_generated=n_raw_tokens_generated,
      max_inference_length=max_inference_length,
      chunk_len=chunk_len,
      chunk_offset=chunk_offset,
      mode=mode)

  assert d_feature % n_heads == 0
  d_head = d_feature // n_heads
  # Per-head bias weights, created locally for this layer with a very small
  # random-normal initialization.
  context_bias_layer = core.Weights(
      init.RandomNormalInitializer(1e-6), shape=(1, n_heads, 1, d_head))
  location_bias_layer = core.Weights(
      init.RandomNormalInitializer(1e-6), shape=(1, n_heads, 1, d_head))

  return cb.Serial(
      cb.Branch(
          cb.Serial(pos_emb, core.Dense(d_feature)),
          core.Dense(d_feature),
          core.Dense(d_feature),
          core.Dense(d_feature),
          cb.Select([1])  # mask
      ),
      context_bias_layer,
      location_bias_layer,
      attention,
      core.Dense(d_feature),
  )