def _SparsifiableDense(layer_sparsity):
  if layer_sparsity is None:
    return core.Dense(d_feature)
  elif layer_sparsity == 'noop':
    return cb.Serial()  # No-op layer.
  else:
    d_module = d_feature // layer_sparsity
    return cb.Serial(
        sparsity.FactoredDense(layer_sparsity, d_feature, d_feature),
        sparsity.LocallyConvDense(layer_sparsity, d_module, mode=mode,
                                  kernel_size=3, length_kernel_size=3)
    )
def test_dense_param_sharing(self):
  model1 = combinators.Serial(core.Dense(32), core.Dense(32))
  layer = core.Dense(32)
  model2 = combinators.Serial(layer, layer)

  input_signature = ShapeDtype((1, 32))
  params1, _ = model1.initialize_once(input_signature)
  params2, _ = model2.initialize_once(input_signature)
  # The first parameters have 2 kernels of size (32, 32).
  self.assertEqual((32, 32), params1[0][0].shape)
  self.assertEqual((32, 32), params1[1][0].shape)
  # The second parameters have 1 kernel of size (32, 32) and an empty tuple.
  self.assertEqual((32, 32), params2[0][0].shape)
  self.assertEqual((), params2[1])
def CountWeights(mask_id=None, has_weights=False):
  """Sum the weights assigned to all elements."""
  if has_weights:
    return cb.Serial(
        cb.Drop(),  # Drop inputs.
        WeightMask(mask_id=mask_id),  # pylint: disable=no-value-for-parameter
        cb.Multiply(),  # Multiply with provided mask.
        core.Sum(axis=None)  # Sum all weights.
    )
  return cb.Serial(
      cb.Drop(),  # Drop inputs.
      WeightMask(mask_id=mask_id),  # pylint: disable=no-value-for-parameter
      core.Sum(axis=None)  # Sum all weights.
  )
def test_dense_param_sharing(self):
  model1 = combinators.Serial(core.Dense(32), core.Dense(32))
  layer = core.Dense(32)
  model2 = combinators.Serial(layer, layer)

  rng1, rng2 = backend.random.split(backend.random.get_prng(0), 2)
  params1, _ = model1.initialize_once((1, 32), onp.float32, rng1)
  params2, _ = model2.initialize_once((1, 32), onp.float32, rng2)
  # The first parameters have 2 kernels of size (32, 32).
  self.assertEqual((32, 32), params1[0][0].shape)
  self.assertEqual((32, 32), params1[1][0].shape)
  # The second parameters have 1 kernel of size (32, 32) and an empty tuple.
  self.assertEqual((32, 32), params2[0][0].shape)
  self.assertEqual((), params2[1])
def test_dense_weight_sharing(self):
  model1 = combinators.Serial(core.Dense(32), core.Dense(32))
  layer = core.Dense(32)
  model2 = combinators.Serial(layer, layer)

  input_signature = ShapeDtype((1, 32))
  weights1, _ = model1.init(input_signature)
  weights2, _ = model2.init(input_signature)
  # The first weights have 2 kernels of size (32, 32).
  self.assertEqual((32, 32), weights1[0][0].shape)
  self.assertEqual((32, 32), weights1[1][0].shape)
  # The second weights have 1 kernel of size (32, 32) and an empty tuple.
  self.assertEqual((32, 32), weights2[0][0].shape)
  self.assertEqual((), weights2[1])
def Accuracy(classifier=core.ArgMax()):
  """Returns a layer that computes mean category prediction accuracy."""
  return cb.Serial(classifier,
                   _Accuracy(),
                   _WeightedMean(),
                   name='Accuracy',
                   sublayers_to_print=[])
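# --- Usage sketch (illustration only, not part of the library code above) ---
# A minimal example of calling the Accuracy metric, assuming the public
# trax.layers / trax.shapes API. The sample shapes and values are invented
# for illustration; metric inputs follow the (model output, targets, weights)
# convention used by Trax metric layers.
import numpy as np
from trax import layers as tl
from trax import shapes

logits = np.array([[[0.1, 0.9], [0.8, 0.2]]], dtype=np.float32)  # (batch, length, classes)
targets = np.array([[1, 1]], dtype=np.int32)
weights = np.ones(targets.shape, dtype=np.float32)

metric = tl.Accuracy()
metric.init(shapes.signature((logits, targets, weights)))
acc = metric((logits, targets, weights))  # ArgMax predicts [1, 0]; weighted mean -> 0.5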
def SequenceAccuracy(classifier=core.ArgMax()):
  """Returns a layer that computes mean sequence prediction accuracy."""
  return cb.Serial(classifier,
                   _Accuracy(),
                   _WeightedSequenceMean(),
                   name='SequenceAccuracy',
                   sublayers_to_print=[])
def AssertFunction(specification, layer, message=None):  # pylint: disable=invalid-name
  """AssertFunction asserts shapes on the input/output tensors of a layer.

  It passes all inputs to the layer, and returns all outputs of the layer
  unchanged.

  Args:
    specification: A specification. See the assert_shape decorator for full
      documentation.
    layer: A base.Layer to wrap around.
    message: An optional message to print if an assert fails. By default it
      prints the filename and the line number where AssertFunction was called.

  Returns:
    The given layer wrapped in asserts on its inputs and outputs.
  """
  if message is None:
    caller = inspect.getframeinfo(inspect.stack()[1][0])
    message = f'Defined at {caller.filename}:{caller.lineno}'
  before_spec, after_spec = specification.split('->')
  before_assert = AssertShape(before_spec, message=message + ' function input')
  after_assert = AssertShape(after_spec, message=message + ' function output')
  after_assert._create_link(before_assert)  # pylint: disable=protected-access
  return combinators.Serial(before_assert, layer, after_assert)
def SumOfWeights():
  """Returns a layer that computes sum of weights."""
  return cb.Serial(
      cb.Drop(),  # Drop inputs.
      cb.Drop(),  # Drop targets.
      core.Sum(axis=None)  # Sum weights.
  )
def _WeightedMaskedMean(metric_layer, final_layer_override=None):
  """Computes weighted masked mean of metric_layer(predictions, targets)."""
  final_layer = final_layer_override or _WeightedMean()  # For sequence acc.
  return cb.Serial(
      metric_layer,
      final_layer
  )
def Attention(d_feature, n_heads=1, dropout=0.0, mode='train'):
  """Returns a layer that maps (activations, mask) to (new_activations, mask).

  This layer type represents one pass of multi-head self-attention, best
  known for its central role in Transformer models. Internally, it:

    - maps incoming sequence of activations to sequence of
      (query, key, value) triples,
    - splits queries, keys, and values into multiple 'heads',
    - computes per-head attention weights from per-head (queries, keys),
    - applies mask to screen out positions that come from padding tokens,
    - [in ``'train'`` mode] applies dropout to attention weights,
    - uses attention weights to combine per-head values vectors, and
    - fuses per-head results into outgoing activations matching original
      input activation shapes.

  Args:
    d_feature: Depth/dimensionality of feature embedding.
    n_heads: Number of attention heads.
    dropout: Probabilistic rate for internal dropout applied to attention
        activations (based on query-key pairs) before dotting them with
        values.
    mode: One of ``'train'``, ``'eval'``, or ``'predict'``.
  """
  return cb.Serial(
      cb.Select([0, 0, 0]),
      AttentionQKV(d_feature, n_heads=n_heads, dropout=dropout, mode=mode),
  )
def QKVLayer():
  """Function returning the Q, K and V layer."""
  if use_dconv:
    return cb.Serial(core.Dense(d_feature),
                     convolution.CausalDepthwiseConv())
  else:
    return core.Dense(d_feature)
def CrossEntropyLossWithLogSoftmax():
  """Mean prediction-target cross-entropy for multiclass classification."""
  return cb.Serial(core.LogSoftmax(),
                   _CrossEntropy(),
                   _WeightedMean(),
                   name='CrossEntropyLossWithLogSoftmax',
                   sublayers_to_print=[])
def AttentionQKV(d_feature, n_heads=1, dropout=0.0, mode='train',
                 cache_KV_in_predict=False, q_sparsity=None,
                 result_sparsity=None):
  """Returns a layer that maps `(AQ, AK, AV, mask)` to `(new-A, mask)`.

  Unlike :py:class:`Attention` above, :py:class:`AttentionQKV` allows the
  incoming activations (`AQ`, `AK`, and `AV`) to come from different sources.
  This is used, for instance, in encoder-decoder attention (Q-related
  activations `AQ` from the decoder, K- and V-related activations -- `AK` and
  `AV` -- from the encoder). Otherwise, see the :py:class:`Attention`
  description for further context/details.

  Args:
    d_feature: Last/innermost dimension of activations in the input to and
        output from this layer.
    n_heads: Number of attention heads. Attention heads effectively split
        activation vectors into ``n_heads`` subvectors, of size
        ``d_feature / n_heads``.
    dropout: Probabilistic rate for attention dropout, which overrides
        (sets to zero) some attention strengths derived from query-key
        matching. As a result, on a given forward pass, some value vectors
        don't contribute to the output, analogous to how regular dropout can
        cause some node activations to be ignored.
    mode: One of ``'train'``, ``'eval'``, or ``'predict'``.
    cache_KV_in_predict: Whether to cache K/V arrays in ``'predict'`` mode.
    q_sparsity: Sparsity with which to process queries. If ``None``,
        :py:class:`Dense` is used; if ``'noop'``, no processing is used.
    result_sparsity: Sparsity with which to process the result of the
        attention. If ``None``, :py:class:`Dense` is used; if ``'noop'``, no
        processing is used.
  """
  def _SparsifiableDense(layer_sparsity):
    if layer_sparsity is None:
      return core.Dense(d_feature)
    elif layer_sparsity == 'noop':
      return cb.Serial()  # No-op layer.
    else:
      d_module = d_feature // layer_sparsity
      return cb.Serial(
          sparsity.FactoredDense(layer_sparsity, d_feature, d_feature),
          sparsity.LocallyConvDense(layer_sparsity, d_module, mode=mode,
                                    kernel_size=3, length_kernel_size=3)
      )

  def _CacheableDense():
    if cache_KV_in_predict and mode == 'predict':
      return cb.Cache(core.Dense(d_feature))
    else:
      return core.Dense(d_feature)

  def _PureAttention():
    return PureAttention(n_heads=n_heads, dropout=dropout, mode=mode)

  return cb.Serial(
      cb.Parallel(_SparsifiableDense(q_sparsity),
                  _CacheableDense(),
                  _CacheableDense()),
      _PureAttention(),
      _SparsifiableDense(result_sparsity),
  )
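# --- Usage sketch (illustration only, not part of the library code above) ---
# A minimal example of calling AttentionQKV on (q, k, v, mask), assuming the
# public trax.layers / trax.shapes API. The (batch, 1, 1, seq_len) boolean
# mask shape follows the usual broadcastable padding-mask convention and is
# an assumption here, not prescribed by the code above.
import numpy as np
from trax import layers as tl
from trax import shapes

batch, seq_len, d_feature = 2, 10, 64
q = k = v = np.zeros((batch, seq_len, d_feature), dtype=np.float32)
mask = np.ones((batch, 1, 1, seq_len), dtype=bool)  # everything visible

layer = tl.AttentionQKV(d_feature, n_heads=4, mode='eval')
layer.init(shapes.signature((q, k, v, mask)))
new_activations, out_mask = layer((q, k, v, mask))  # new_activations: (2, 10, 64)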
def Attention(d_feature, n_heads=1, dropout=0.0, mode='train'):
  """Returns a layer that maps (vectors, mask) to (new_vectors, mask).

  This layer type represents one pass of multi-head self-attention, from
  vector set to vector set, using masks to represent out-of-bound (e.g.,
  padding) positions. It:

    - maps incoming sequence of activation vectors to sequence of
      (query, key, value) triples,
    - splits queries, keys, and values into multiple 'heads',
    - computes per-head attention weights from per-head (queries, keys),
    - applies mask to screen out positions that come from padding tokens,
    - [in ``'train'`` mode] applies dropout to attention weights,
    - uses attention weights to combine per-head values vectors, and
    - fuses per-head results into outgoing activations matching original
      input activation shapes.

  Args:
    d_feature: Depth/dimensionality of feature embedding.
    n_heads: Number of attention heads.
    dropout: Probabilistic rate for attention dropout, which overrides
        (sets to zero) some attention strengths derived from query-key
        matching. As a result, on a given forward pass, some value vectors
        don't contribute to the output, analogous to how regular dropout can
        cause some node activations to be ignored.
    mode: One of ``'train'``, ``'eval'``, or ``'predict'``.
  """
  return cb.Serial(
      cb.Select([0, 0, 0]),
      AttentionQKV(d_feature, n_heads=n_heads, dropout=dropout, mode=mode),
  )
def SRU(n_units, activation=None):
  """SRU (Simple Recurrent Unit) layer, as in https://arxiv.org/abs/1709.02755.

  As defined in the paper:

  (1) y_t = W x_t (+ B optionally, which we do)
  (2) f_t = sigmoid(Wf x_t + bf)
  (3) r_t = sigmoid(Wr x_t + br)
  (4) c_t = f_t * c_{t-1} + (1 - f_t) * y_t
  (5) h_t = r_t * activation(c_t) + (1 - r_t) * x_t

  We assume the input is of shape [batch, length, depth] and recurrence
  happens on the length dimension. This returns a single layer; the paper
  suggests stacking at least two SRU layers, except inside a Transformer.

  Args:
    n_units: output depth of the SRU layer.
    activation: Optional activation function.

  Returns:
    The SRU layer.
  """
  sigmoid_activation = activation_fns.Sigmoid()  # pylint: disable=no-value-for-parameter
  return cb.Serial(                                         # x
      cb.Branch(core.Dense(3 * n_units), []),               # r_f_y, x
      cb.Split(n_items=3),                                  # r, f, y, x
      cb.Parallel(sigmoid_activation, sigmoid_activation),  # r, f, y, x
      base.Fn(lambda r, f, y: (y * (1.0 - f), f, r)),       # y * (1 - f), f, r, x
      cb.Parallel([], [], cb.Branch(MakeZeroState(), [])),
      cb.Scan(InnerSRUCell(), axis=1),
      cb.Select([0], n_in=2),                               # act(c), r, x
      activation or [],
      base.Fn(lambda c, r, x: c * r + x * (1 - r)))
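# --- Usage sketch (illustration only, not part of the library code above) ---
# A small example of calling SRU, assuming the public trax.layers API exposes
# it as tl.SRU. The input depth matches n_units so the final gated residual
# with x lines up; shapes are illustrative.
import numpy as np
from trax import layers as tl
from trax import shapes

sru = tl.SRU(32)
x = np.zeros((4, 10, 32), dtype=np.float32)  # [batch, length, depth]
sru.init(shapes.signature(x))
y = sru(x)  # same shape: (4, 10, 32)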
def AttentionQKV(d_feature, n_heads=1, dropout=0.0, mode='train'):
  """Transformer-style multi-headed attention.

  Accepts inputs of the form q, k, v, mask.

  Args:
    d_feature: int: dimensionality of feature embedding
    n_heads: int: number of attention heads
    dropout: float: dropout rate
    mode: str: 'train' or 'eval'

  Returns:
    Multi-headed self-attention result and the mask.
  """
  return cb.Serial(
      cb.Parallel(
          core.Dense(d_feature),
          core.Dense(d_feature),
          core.Dense(d_feature),
      ),
      PureAttention(  # pylint: disable=no-value-for-parameter
          n_heads=n_heads, dropout=dropout, mode=mode),
      core.Dense(d_feature),
  )
def Attention(d_feature, n_heads=1, dropout=0.0, mode='train'):
  """Returns a layer that maps (activations, mask) to (new_activations, mask).

  This layer type represents one pass of multi-head self-attention, best
  known for its central role in Transformer models. Internally, it:

    - maps activations to `(queries, keys, values)` triples,
    - splits `queries`, `keys`, and `values` into multiple 'heads',
    - computes per-head attention weights from per-head `(queries, keys)`,
    - applies `mask` to screen out positions that come from padding tokens,
    - optionally applies dropout to attention weights,
    - uses attention weights to combine per-head `values` vectors, and
    - fuses per-head results into activations matching original input shapes.

  Args:
    d_feature: Depth/dimensionality of feature embedding.
    n_heads: Number of attention heads.
    dropout: Probabilistic rate for internal dropout applied to attention
        activations (based on query-key pairs) before dotting them with
        values.
    mode: Either 'train' or 'eval'.
  """
  return cb.Serial(
      cb.Dup(), cb.Dup(),  # TODO(jonni): replace with Select([0, 0, 0])
      AttentionQKV(d_feature, n_heads=n_heads, dropout=dropout, mode=mode),
  )
def SumOfWeights():
  """Returns a layer to compute sum of weights of all non-masked elements."""
  return cb.Serial(
      cb.Drop(),  # Drop inputs.
      cb.Drop(),  # Drop targets.
      core.Sum(axis=None)  # Sum weights.
  )
def GeneralGRUCell(candidate_transform,
                   memory_transform_fn=None,
                   gate_nonlinearity=core.Sigmoid,
                   candidate_nonlinearity=core.Tanh,
                   dropout_rate_c=0.1,
                   sigmoid_bias=0.5):
  r"""Parametrized Gated Recurrent Unit (GRU) cell construction.

  GRU update equations:

  $$ \text{Update gate: } u_t = \sigma(U' s_{t-1} + B') $$
  $$ \text{Reset gate: } r_t = \sigma(U'' s_{t-1} + B'') $$
  $$ \text{Candidate memory: } c_t = \tanh(U (r_t \odot s_{t-1}) + B) $$
  $$ \text{New state: } s_t = u_t \odot s_{t-1} + (1 - u_t) \odot c_t $$

  See combinators.Gate for details on the gating function.

  Args:
    candidate_transform: Transform to apply inside the Candidate branch.
      Applied before nonlinearities.
    memory_transform_fn: Optional transformation on the memory before gating.
    gate_nonlinearity: Function to use as gate activation; allows trying
      alternatives to Sigmoid, such as HardSigmoid.
    candidate_nonlinearity: Nonlinearity to apply after the candidate branch;
      allows trying alternatives to the traditional Tanh, such as HardTanh.
    dropout_rate_c: Amount of dropout on the transform (c) gate. Dropout works
      best in a GRU when applied exclusively to this branch.
    sigmoid_bias: Constant to add before sigmoid gates. Generally want to
      start off with a positive bias.

  Returns:
    A model representing a GRU cell with specified transforms.
  """
  gate_block = [  # u_t
      candidate_transform(),
      core.AddConstant(constant=sigmoid_bias),
      gate_nonlinearity(),
  ]
  reset_block = [  # r_t
      candidate_transform(),
      core.AddConstant(constant=sigmoid_bias),  # Want bias to start positive.
      gate_nonlinearity(),
  ]
  candidate_block = [
      cb.Dup(),
      reset_block,
      cb.Multiply(),  # Gate S{t-1} with sigmoid(candidate_transform(S{t-1}))
      candidate_transform(),  # Final projection + tanh to get Ct
      candidate_nonlinearity(),  # Candidate gate
      # Only apply dropout on the C gate. Paper reports 0.1 as a good default.
      core.Dropout(rate=dropout_rate_c)
  ]
  memory_transform = memory_transform_fn() if memory_transform_fn else []
  return cb.Serial(
      cb.Dup(), cb.Dup(),
      cb.Parallel(memory_transform, gate_block, candidate_block),
      cb.Gate(),
  )
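# --- Usage sketch (illustration only, not part of the library code above) ---
# Per the construction above, the cell maps a single state tensor s_{t-1} to
# a new state s_t of the same shape. This example assumes the cell is exposed
# via trax.layers (as GeneralGRUCell in the Trax package layout) and that
# candidate_transform maps the state to the same depth; shapes are invented.
import numpy as np
from trax import layers as tl
from trax import shapes

cell = tl.GeneralGRUCell(candidate_transform=lambda: tl.Dense(32))

s_prev = np.zeros((8, 32), dtype=np.float32)  # previous state s_{t-1}
cell.init(shapes.signature(s_prev))
s_new = cell(s_prev)  # new state s_t, shape (8, 32)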
def RelativeAttentionLMLayer(d_feature,
                             total_kv_pooling,
                             n_heads=1,
                             dropout=0.0,
                             n_raw_tokens_generated=1,
                             max_inference_length=3072,
                             chunk_len=None,
                             chunk_offset=None,
                             mode='train'):
  """Returns a layer that maps (q, k, v) to (activations).

  Same as the standard relative attention layer, except that it additionally
  prepares a mask, based on the sizes of the queries and keys, that masks out
  the future. Masking the future is primarily used for language modelling.

  Args:
    d_feature: Depth/dimensionality of feature embedding.
    total_kv_pooling: Accumulated pool size of keys/values used at this layer.
    n_heads: Number of attention heads.
    dropout: Probabilistic rate for internal dropout applied to attention
        activations (based on query-key pairs) before dotting them with
        values.
    n_raw_tokens_generated: Number of tokens generated in a single pass
        through this layer. Used only in 'predict' non-training mode.
    max_inference_length: Maximum sequence length allowed in non-training
        modes.
    chunk_len (optional): Number of tokens per chunk. Setting this option
        enables chunked attention.
    chunk_offset (optional): Offset for shifting chunks, for shifted chunked
        attention.
    mode: One of `'train'`, `'eval'`, or `'predict'`.
  """
  attention = RelativeAttentionLayer(
      d_feature,
      total_kv_pooling,
      n_heads=n_heads,
      dropout=dropout,
      n_raw_tokens_generated=n_raw_tokens_generated,
      max_inference_length=max_inference_length,
      chunk_len=chunk_len,
      chunk_offset=chunk_offset,
      mode=mode)

  mask_layer = AttentionMaskLayer(
      total_kv_pooling=total_kv_pooling,
      max_inference_length=max_inference_length,
      chunk_len=chunk_len,
      chunk_offset=chunk_offset,
      n_raw_tokens_generated=n_raw_tokens_generated,
      mode=mode)

  return cb.Serial(
      cb.Branch(
          None,
          mask_layer,  # vecs, mask
      ),
      attention,  # vecs, mask
      cb.Select([0], n_in=2),  # vecs
  )
def CausalAttention(d_feature, n_heads=1, dropout=0.0,
                    max_inference_length=2048, mode='train'):
  """Returns a layer that maps activations to activations, with causal masking.

  Like `Attention`, this layer type represents one pass of multi-head
  self-attention, but with causal masking rather than padding-based masking.

  Args:
    d_feature: Depth/dimensionality of feature embedding.
    n_heads: Number of attention heads.
    dropout: Probabilistic rate for internal dropout applied to attention
        activations (based on query-key pairs) before dotting them with
        values.
    max_inference_length: maximum length for inference.
    mode: One of `'train'`, `'eval'`, or `'predict'`.
  """
  if d_feature % n_heads != 0:
    raise ValueError(
        f'Dimensionality of feature embedding ({d_feature}) is not a multiple '
        f'of the requested number of attention heads ({n_heads}).')

  d_head = d_feature // n_heads

  def _split_into_heads():
    """Returns a layer that reshapes tensors for multi-headed computation."""
    def f(x):
      batch_size = x.shape[0]
      seq_len = x.shape[1]

      # (b_size, seq_len, d_feature) --> (b_size*n_heads, seq_len, d_head)
      x = x.reshape((batch_size, seq_len, n_heads, d_head))
      x = x.transpose((0, 2, 1, 3))
      x = x.reshape((-1, seq_len, d_head))
      return x
    return Fn('SplitIntoHeads', f)

  def _merge_heads():
    """Returns a layer that undoes splitting, after multi-head computation."""
    def f(x):
      seq_len = x.shape[1]

      # (b_size*n_heads, seq_len, d_head) --> (b_size, seq_len, d_feature)
      x = x.reshape((-1, n_heads, seq_len, d_head))
      x = x.transpose((0, 2, 1, 3))
      x = x.reshape((-1, seq_len, n_heads * d_head))
      return x
    return Fn('MergeHeads', f)

  return cb.Serial(
      cb.Branch(
          [core.Dense(d_feature), _split_into_heads()],
          [core.Dense(d_feature), _split_into_heads()],
          [core.Dense(d_feature), _split_into_heads()],
      ),
      DotProductCausalAttention(
          dropout=dropout, max_inference_length=max_inference_length,
          mode=mode),
      _merge_heads(),
      core.Dense(d_feature),
  )
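# --- Usage sketch (illustration only, not part of the library code above) ---
# A minimal example of calling CausalAttention, assuming the public
# trax.layers / trax.shapes API; shapes are illustrative. The layer takes a
# single activations tensor and applies the causal mask internally.
import numpy as np
from trax import layers as tl
from trax import shapes

batch, seq_len, d_feature = 2, 16, 64
x = np.zeros((batch, seq_len, d_feature), dtype=np.float32)

attn = tl.CausalAttention(d_feature, n_heads=4, mode='eval')
attn.init(shapes.signature(x))
y = attn(x)  # (2, 16, 64); position t attends only to positions <= t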
def LSTM(n_units):
  """LSTM running on axis 1."""
  zero_state = MakeZeroState(depth_multiplier=2)  # pylint: disable=no-value-for-parameter
  return cb.Serial(
      cb.Branch([], zero_state),
      cb.Scan(LSTMCell(n_units=n_units), axis=1),
      cb.Select([0], n_in=2)  # Drop RNN state.
  )
def RelativeAttentionLayer(d_feature,
                           context_bias_layer,
                           location_bias_layer,
                           total_kv_pooling,
                           separate_cls,
                           n_heads=1,
                           dropout=0.0,
                           mode='train'):
  """Returns a layer that maps (q, k, v, masks) to (activations, masks).

  When the number of keys is smaller than the number of queries, the layer
  works in O(q^2 * d); otherwise it is O(q * k * d). That is because we need
  to shift relative distances by the current pooling ratio, and when we
  upsample, the current pooling ratio is a fraction < 1.

  Visual explanation:

    [01][23][45][67] -> [0][1][2][3][4][5][6][7]

  For token [0] we calculate relative distances as follows:

    * 0 2 4 6

  However, for token [1] we need the relative distances shifted by 1,
  specifically:

    * -1 1 3 5

  So we need to calculate not only the distances that correspond to the
  spacing between the keys, but also the ones in between, because there is
  more than one query token (at different positions, hence with different
  relative distances) per key token.

  Args:
    d_feature: Depth/dimensionality of feature embedding.
    context_bias_layer: Global context bias from Transformer XL's attention.
      There should be one such layer shared for all relative attention layers.
    location_bias_layer: Global location bias from Transformer XL's attention.
      There should be one such layer shared for all relative attention layers.
    total_kv_pooling: Accumulated pool size of keys/values used at this layer.
    separate_cls: True/False if we separate_cls in calculations.
    n_heads: Number of attention heads.
    dropout: Probabilistic rate for internal dropout applied to attention
        activations (based on query-key pairs) before dotting them with
        values.
    mode: One of `'train'`, `'eval'`, or `'predict'`.
  """
  return cb.Serial(
      cb.Branch(
          PositionalEmbeddings(d_feature, separate_cls, total_kv_pooling),
          cb.Select([0]), cb.Select([1])),
      cb.Parallel(
          core.Dense(d_feature),
          core.Dense(d_feature),
          core.Dense(d_feature),
          core.Dense(d_feature),
      ),
      context_bias_layer,
      location_bias_layer,
      RelativeAttention(  # pylint: disable=no-value-for-parameter
          separate_cls=separate_cls,
          n_heads=n_heads,
          dropout=dropout,
          mode=mode),
      core.Dense(d_feature),
  )
def AttentionResampling(shorten_factor, d_model, is_upsampling, d_ff, n_heads,
                        dropout, dropout_shared_axes, mode, ff_activation,
                        context_bias_layer, location_bias_layer, total_pooling,
                        resampling_fn):
  """Attention resampling."""

  attention = RelativeAttentionLMLayer(
      d_model, context_bias_layer, location_bias_layer, total_pooling,
      n_heads=n_heads, dropout=dropout, mode=mode)

  feed_forward = FeedForwardBlock(
      d_model, d_ff, dropout, dropout_shared_axes, mode, ff_activation)

  resampling = resampling_fn(shorten_factor, d_model, mode=mode)

  def _Dropout():
    return core.Dropout(rate=dropout, shared_axes=dropout_shared_axes,
                        mode=mode)

  return [
      LayerNorm(),            # h
      cb.Branch(cb.Serial(
          resampling,
          LayerNorm(),
      ), None),               # h', h
      cb.Serial(              # pylint: disable=g-long-ternary
          cb.Select([0, 2, 1, 2]),
          cb.Add(),
      ) if is_upsampling else [],
      cb.Residual(
          cb.Select([0, 1, 1]),  # h', h, h
          attention,
          _Dropout(),
      ),
      cb.Residual(
          LayerNorm(),
          feed_forward,
          _Dropout(),
      ),
  ]
def LSTM(n_units, mode='train', return_state=False, initial_state=False):
  """LSTM running on axis 1.

  Args:
    n_units: `n_units` for the `LSTMCell`.
    mode: if 'predict' then we save the previous state for one-by-one
      inference.
    return_state: Boolean. Whether to return the latest state in addition to
      the output. Default: False.
    initial_state: Boolean. Whether the initial RNN state (c, h) is to be
      obtained from the stack. Default: False.

  Returns:
    An LSTM layer.
  """
  if not initial_state:
    zero_state = MakeZeroState(depth_multiplier=2)  # pylint: disable=no-value-for-parameter
    if return_state:
      return cb.Serial(
          cb.Branch([], zero_state),
          cb.Scan(LSTMCell(n_units=n_units), axis=1, mode=mode),
          name=f'LSTM_{n_units}', sublayers_to_print=[])
    else:
      return cb.Serial(
          cb.Branch([], zero_state),  # Fill state RNN with zeros.
          cb.Scan(LSTMCell(n_units=n_units), axis=1, mode=mode),
          cb.Select([0], n_in=2),  # Drop RNN state.
          # Set the name to LSTM and don't print sublayers.
          name=f'LSTM_{n_units}', sublayers_to_print=[])
  else:
    if return_state:
      return cb.Serial(
          cb.Scan(LSTMCell(n_units=n_units), axis=1, mode=mode),
          name=f'LSTM_{n_units}', sublayers_to_print=[])
    else:
      return cb.Serial(
          cb.Scan(LSTMCell(n_units=n_units), axis=1, mode=mode),
          cb.Select([0], n_in=2),  # Drop RNN state.
          name=f'LSTM_{n_units}', sublayers_to_print=[])
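# --- Usage sketch (illustration only, not part of the library code above) ---
# A small example of calling LSTM with the defaults (zero initial state, state
# dropped from the output), assuming the public trax.layers / trax.shapes API;
# shapes are illustrative.
import numpy as np
from trax import layers as tl
from trax import shapes

lstm = tl.LSTM(n_units=16)
x = np.zeros((2, 7, 16), dtype=np.float32)  # [batch, length, depth]
lstm.init(shapes.signature(x))
y = lstm(x)  # (2, 7, 16): one 16-dimensional output per position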
def test_input_signatures_serial(self):
  layer = cb.Serial(core.Div(divisor=2.0), core.Div(divisor=5.0))
  self.assertIsNone(layer.input_signature)

  layer.input_signature = ShapeDtype((3, 2))
  self.assertEqual(layer.input_signature, ShapeDtype((3, 2)))
  self.assertLen(layer.sublayers, 2)
  for sublayer in layer.sublayers:
    self.assertEqual(sublayer.input_signature, ShapeDtype((3, 2)))
def PositionalEmbeddings(d_feature, separate_cls, total_kv_pooling):
  """Positional embedding for relative attention.

  Returns a layer that, based on the queries, the keys, and the accumulated
  pool size of keys/values up to this layer, calculates sinusoidal positional
  embeddings for relative attention calculations.

  Args:
    d_feature: Depth/dimensionality of feature embedding.
    separate_cls: True/False if we separate_cls in calculations.
    total_kv_pooling: Accumulated pool size of keys/values until this layer.

  Returns:
    Positional embedding.
  """

  def PositionsVectors(queries, keys):
    is_funnel_layer = queries.shape != keys.shape
    keys_len, queries_len = keys.shape[1], queries.shape[1]
    current_pooling_ratio = keys_len / queries_len

    # Special case of upsampling.
    if is_funnel_layer and current_pooling_ratio < 1:
      # We should not be doing standard upsampling when we use separate_cls;
      # the cls token is being used for classification.
      assert not separate_cls
      assert (total_kv_pooling * keys_len) % queries_len == 0
      multiplier = ((total_kv_pooling * keys_len) // queries_len)
      positions = jnp.arange(-queries_len + 1, queries_len, 1.0) * multiplier
    else:
      positions = jnp.arange(-keys_len + 1, keys_len, 1.0) * total_kv_pooling

    if is_funnel_layer and separate_cls:
      # For pool_size 2 without separating cls we have got
      # [0][1][2][3][4][5][6][7] -> [01][23][45][67]
      # With separating cls we have got
      # [0][1][2][3][4][5][6][7] -> [0][12][34][56]
      # The first group will always consist of one token after pooling,
      # instead of (pool_size) tokens. We need to add a proper offset so
      # that our shift later on in calculating attention works properly.
      cls_offset = (current_pooling_ratio - 1) * total_kv_pooling
      positions = positions + cls_offset

    return positions

  def Sinusoidal_Embeddings(positions):
    inv_freq = 1 / (10000**(jnp.arange(0.0, d_feature, 2.0) / d_feature))
    sinusoid_freq = jnp.einsum('i,j->ij', positions, inv_freq)
    pos_emb = jnp.concatenate(
        [jnp.sin(sinusoid_freq), jnp.cos(sinusoid_freq)], axis=1)
    return pos_emb

  return cb.Serial(
      cb.Fn('Generate positions vectors', PositionsVectors, n_out=1),
      cb.Fn('Transform to sinusoidal encodings', Sinusoidal_Embeddings,
            n_out=1))
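# --- Worked example (illustration only, not part of the library code above) ---
# Hand-computed PositionsVectors values under the definitions above; the
# lengths and pooling sizes are invented for illustration.
import numpy as np

# Non-funnel layer: queries_len == keys_len == 4, total_kv_pooling = 2.
positions = np.arange(-3, 4, 1.0) * 2
print(positions)  # [-6. -4. -2.  0.  2.  4.  6.]

# Upsampling funnel layer: queries_len = 8, keys_len = 4, total_kv_pooling = 2.
multiplier = (2 * 4) // 8  # = 1
positions = np.arange(-7, 8, 1.0) * multiplier
print(positions)  # [-7. -6. ...  6.  7.], covering every intermediate distance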
def test_input_signatures_serial(self):
  layer = cb.Serial(divide_by(2.0), divide_by(5.0))
  self.assertIsNone(layer.input_signature)

  layer._set_input_signature_recursive(ShapeDtype((3, 2)))
  self.assertEqual(layer.input_signature, ShapeDtype((3, 2)))
  self.assertLen(layer.sublayers, 2)
  for sublayer in layer.sublayers:
    self.assertEqual(sublayer.input_signature, ShapeDtype((3, 2)))
def SumOfWeights(id_to_mask=None, has_weights=False):
  """Returns a layer to compute sum of weights of all non-masked elements."""
  multiply_by_weights = cb.Multiply() if has_weights else []
  return cb.Serial(
      cb.Drop(),  # Drop inputs.
      _ElementMask(id_to_mask=id_to_mask),
      multiply_by_weights,
      core.Sum(axis=None)  # Sum all.
  )