def SRU(n_units, activation=None):
  """SRU (Simple Recurrent Unit) layer as in https://arxiv.org/abs/1709.02755.

  As defined in the paper:

  (1) y_t = W x_t (+ B optionally, which we do)
  (2) f_t = sigmoid(Wf x_t + bf)
  (3) r_t = sigmoid(Wr x_t + br)
  (4) c_t = f_t * c_{t-1} + (1 - f_t) * y_t
  (5) h_t = r_t * activation(c_t) + (1 - r_t) * x_t

  We assume the input is of shape [batch, length, depth] and recurrence
  happens on the length dimension. This returns a single layer. It's best
  to use at least 2, they say in the paper, except inside a Transformer.

  Args:
    n_units: output depth of the SRU layer.
    activation: Optional activation function.

  Returns:
    The SRU layer.
  """
  sigmoid_activation = activation_fns.Sigmoid()  # pylint: disable=no-value-for-parameter
  return cb.Serial(                                    # x
      cb.Branch(core.Dense(3 * n_units), []),          # r_f_y, x
      cb.Split(n_items=3),                             # r, f, y, x
      cb.Parallel(sigmoid_activation, sigmoid_activation),  # r, f, y, x
      base.Fn(lambda r, f, y: (y * (1.0 - f), f, r)),  # y * (1 - f), f, r, x
      cb.Parallel([], [], cb.Branch(MakeZeroState(), [])),
      cb.Scan(InnerSRUCell(), axis=1),
      cb.Select([0], n_in=2),                          # act(c), r, x
      activation or [],
      base.Fn(lambda c, r, x: c * r + x * (1 - r)))
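# The recurrence that cb.Scan(InnerSRUCell(), axis=1) runs above can be
# written out directly. A minimal NumPy sketch of equations (1)-(5), assuming
# the projections y, f, r (normally produced by the Dense/Split layers above)
# are precomputed and share the shape [length, depth] with x:
import numpy as np

def sru_recurrence_sketch(y, f, r, x, activation=lambda c: c):
  """Runs the SRU cell over the length dimension, one step at a time."""
  length, depth = y.shape
  c = np.zeros(depth)  # c_0, the zero state made by MakeZeroState().
  h = np.zeros_like(y)
  for t in range(length):
    c = f[t] * c + (1.0 - f[t]) * y[t]                  # (4) internal state
    h[t] = r[t] * activation(c) + (1.0 - r[t]) * x[t]   # (5) highway output
  return h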
def CausalAttention(d_feature, n_heads=1,
                    d_attention_key=None, d_attention_value=None,
                    attention_type=DotProductCausalAttention,
                    share_qk=False, mode='train'):
  """Transformer-style multi-headed causal attention.

  Args:
    d_feature: int: dimensionality of feature embedding
    n_heads: int: number of attention heads
    d_attention_key: int: depth of key vector for each attention head
      (default is d_feature // n_heads)
    d_attention_value: int: depth of value vector for each attention head
      (default is d_feature // n_heads)
    attention_type: subclass of BaseCausalAttention: attention class to use
    share_qk: bool, whether to share queries and keys
    mode: str: 'train' or 'eval'

  Returns:
    Multi-headed self-attention result.
  """
  if d_attention_key is None:
    assert d_feature % n_heads == 0
    d_attention_key = d_feature // n_heads
  if d_attention_value is None:
    assert d_feature % n_heads == 0
    d_attention_value = d_feature // n_heads

  if share_qk:
    pre_attention = [
        cb.Dup(),
        cb.Parallel(
            ComputeAttentionHeads(n_heads=n_heads, d_head=d_attention_key),
            ComputeAttentionHeads(n_heads=n_heads, d_head=d_attention_value),
        ),
        cb.Dup(),
    ]
  else:
    pre_attention = [
        cb.Dup(), cb.Dup(),
        cb.Parallel(
            ComputeAttentionHeads(n_heads=n_heads, d_head=d_attention_key),
            ComputeAttentionHeads(n_heads=n_heads, d_head=d_attention_key),
            ComputeAttentionHeads(n_heads=n_heads, d_head=d_attention_value),
        ),
    ]

  return cb.Serial(pre_attention + [
      attention_type(mode=mode),
      ComputeAttentionOutput(n_heads=n_heads, d_model=d_feature),
  ])
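# "Causal" above means each position may attend only to itself and earlier
# positions. A minimal NumPy sketch of how such a mask constrains the
# attention weights (an illustration, not Trax's DotProductCausalAttention
# internals):
import numpy as np

def causal_dot_product_attention_sketch(q, k, v):
  """q, k, v: arrays of shape [length, d_head]."""
  scores = q @ k.T / np.sqrt(q.shape[-1])          # [length, length]
  mask = np.tril(np.ones(scores.shape, dtype=bool))
  scores = np.where(mask, scores, -1e9)            # block attention to the future
  weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
  weights /= weights.sum(axis=-1, keepdims=True)   # softmax over keys
  return weights @ v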
def _WeightedMaskedMean(metric_layer, id_to_mask, has_weights):
  """Computes weighted masked mean of metric_layer(predictions, targets)."""
  multiply_by_weights = cb.Multiply() if has_weights else []
  # Create a layer with 2 or 3 inputs:
  #   - predictions targets (weights)
  # that applies the specified metric to a batch and gathers the results into
  # a single scalar.
  return cb.Serial(
      cb.Select([0, 1, 1]),
      cb.Parallel(metric_layer, _ElementMask(id_to_mask=id_to_mask)),
      cb.Parallel([], multiply_by_weights),  # Stack now: metric_values weights
      _WeightedMean()
  )
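# What the stack manipulation above computes, written out in NumPy. A sketch
# under the assumption that _ElementMask produces a 0/1 mask where targets
# differ from id_to_mask, and _WeightedMean divides by the mask/weight sum:
import numpy as np

def weighted_masked_mean_sketch(metric_values, targets, id_to_mask,
                                weights=None):
  mask = (targets != id_to_mask).astype(np.float32)
  if weights is not None:  # the has_weights / cb.Multiply() branch
    mask = mask * weights
  return np.sum(metric_values * mask) / np.sum(mask)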
def test_symbolic_decorator3(self):
  add_lyr = cb.Add()
  tanh_lyr = cb.Parallel(activation_fns.Relu(), activation_fns.Tanh())

  @tracer.symbolic
  def make_layer(a, b, c):
    d = add_lyr @ (a, b)
    e = add_lyr @ (d, c)
    f, g = tanh_lyr @ (d, e)
    return f, g

  layer = make_layer()  # pylint: disable=no-value-for-parameter

  a = onp.random.uniform(-10, 10, size=(2, 10))
  b = onp.random.uniform(-10, 10, size=(2, 10))
  c = onp.random.uniform(-10, 10, size=(2, 10))

  input_sd = ShapeDtype((2, 10), onp.int32)
  input_signature = (input_sd, input_sd, input_sd)
  p, s = layer.new_weights_and_state(input_signature)
  res = layer((a, b, c), weights=p, state=s, rng=jax.random.PRNGKey(0))  # pylint: disable=unexpected-keyword-arg,no-value-for-parameter,not-callable

  result0 = onp.array(res[0])
  expected0 = onp.where(a + b > 0, a + b, 0.0)
  onp.testing.assert_allclose(result0, expected0, rtol=1e-5)
  result1 = onp.array(res[1])
  expected1 = onp.tanh(a + b + c)
  onp.testing.assert_allclose(result1, expected1, rtol=1e-5)
def AttentionQKV(d_feature, n_heads=1, dropout=0.0, mode='train',
                 cache_KV_in_predict=False, q_sparsity=None,
                 result_sparsity=None):
  """Returns a layer that maps `(AQ, AK, AV, mask)` to `(new-A, mask)`.

  Unlike :py:class:`Attention` above, :py:class:`AttentionQKV` allows the
  incoming activations (`AQ`, `AK`, and `AV`) to come from different sources.
  This is used, for instance, in encoder-decoder attention (Q-related
  activations `AQ` from the decoder, K- and V-related activations -- `AK` and
  `AV` -- from the encoder). Otherwise, see the :py:class:`Attention`
  description for further context/details.

  Args:
    d_feature: Last/innermost dimension of activations in the input to and
        output from this layer.
    n_heads: Number of attention heads. Attention heads effectively split
        activation vectors into ``n_heads`` subvectors, of size
        ``d_feature / n_heads``.
    dropout: Probabilistic rate for attention dropout, which overrides
        (sets to zero) some attention strengths derived from query-key
        matching. As a result, on a given forward pass, some value vectors
        don't contribute to the output, analogous to how regular dropout can
        cause some node activations to be ignored.
    mode: One of ``'train'``, ``'eval'``, or ``'predict'``.
    cache_KV_in_predict: Whether to cache K/V arrays in ``'predict'`` mode.
    q_sparsity: Sparsity with which to process queries. If ``None``,
        :py:class:`Dense` is used; if ``'noop'``, no processing is used.
    result_sparsity: Sparsity with which to process result of the attention.
        If ``None``, :py:class:`Dense` is used; if ``'noop'``, no processing
        is used.
  """
  def _SparsifiableDense(layer_sparsity):
    if layer_sparsity is None:
      return core.Dense(d_feature)
    elif layer_sparsity == 'noop':
      return cb.Serial()  # No-op layer.
    else:
      d_module = d_feature // layer_sparsity
      return cb.Serial(
          sparsity.FactoredDense(layer_sparsity, d_feature, d_feature),
          sparsity.LocallyConvDense(layer_sparsity, d_module, mode=mode,
                                    kernel_size=3, length_kernel_size=3)
      )

  def _CacheableDense():
    if cache_KV_in_predict and mode == 'predict':
      return cb.Cache(core.Dense(d_feature))
    else:
      return core.Dense(d_feature)

  def _PureAttention():
    return PureAttention(n_heads=n_heads, dropout=dropout, mode=mode)

  return cb.Serial(
      cb.Parallel(_SparsifiableDense(q_sparsity),
                  _CacheableDense(),
                  _CacheableDense()),
      _PureAttention(),
      _SparsifiableDense(result_sparsity),
  )
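# How "attention heads effectively split activation vectors into n_heads
# subvectors of size d_feature / n_heads" looks on a concrete array. A shape
# sketch only, not Trax's PureAttention implementation:
import numpy as np

batch, length, d_feature, n_heads = 2, 7, 16, 4
d_head = d_feature // n_heads  # 4 subvectors of size 4
activations = np.zeros((batch, length, d_feature))
per_head = activations.reshape(batch, length, n_heads, d_head)
per_head = per_head.transpose(0, 2, 1, 3)  # [batch, n_heads, length, d_head]
assert per_head.shape == (2, 4, 7, 4)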
def AttentionQKV(d_feature, n_heads=1, dropout=0.0, mode='train'):
  """Transformer-style multi-headed attention.

  Accepts inputs of the form q, k, v, mask.

  Args:
    d_feature: int: dimensionality of feature embedding
    n_heads: int: number of attention heads
    dropout: float: dropout rate
    mode: str: 'train' or 'eval'

  Returns:
    Multi-headed self-attention result and the mask.
  """
  return cb.Serial(
      cb.Parallel(
          core.Dense(d_feature),
          core.Dense(d_feature),
          core.Dense(d_feature),
      ),
      PureAttention(  # pylint: disable=no-value-for-parameter
          n_heads=n_heads, dropout=dropout, mode=mode),
      core.Dense(d_feature),
  )
def test_symbolic_decorator3(self):
  add_lyr = cb.Add()
  tanh_lyr = cb.Parallel(core.Relu(), core.Tanh())

  @tracer.symbolic
  def make_layer(a, b, c):
    d = add_lyr << (a, b)
    e = add_lyr << (d, c)
    f, g = tanh_lyr << (d, e)
    return f, g

  layer = make_layer()  # pylint: disable=no-value-for-parameter

  a = onp.random.uniform(-10, 10, size=(2, 10))
  b = onp.random.uniform(-10, 10, size=(2, 10))
  c = onp.random.uniform(-10, 10, size=(2, 10))

  p, s = layer.new_params_and_state(
      ((2, 10), (2, 10), (2, 10)),
      (onp.float32, onp.float32, onp.float32),
      rng=jax.random.PRNGKey(0))
  res = layer((a, b, c), params=p, state=s, rng=jax.random.PRNGKey(0))  # pylint: disable=unexpected-keyword-arg,no-value-for-parameter,not-callable

  result0 = onp.array(res[0])
  expected0 = onp.where(a + b > 0, a + b, 0.0)
  onp.testing.assert_allclose(result0, expected0, rtol=1e-5)
  result1 = onp.array(res[1])
  expected1 = onp.tanh(a + b + c)
  onp.testing.assert_allclose(result1, expected1, rtol=1e-5)
def GeneralGRUCell(candidate_transform,
                   memory_transform_fn=None,
                   gate_nonlinearity=core.Sigmoid,
                   candidate_nonlinearity=core.Tanh,
                   dropout_rate_c=0.1,
                   sigmoid_bias=0.5):
  r"""Parametrized Gated Recurrent Unit (GRU) cell construction.

  GRU update equations:

  .. math::
     u_t &= \sigma(U' s_{t-1} + B') && \hbox{(update gate)} \\
     r_t &= \sigma(U'' s_{t-1} + B'') && \hbox{(reset gate)} \\
     c_t &= \tanh(U (r_t \odot s_{t-1}) + B) && \hbox{(candidate memory)} \\
     s_t &= u_t \odot s_{t-1} + (1 - u_t) \odot c_t && \hbox{(new state)}

  See combinators.Gate for details on the gating function.

  Args:
    candidate_transform: Transform to apply inside the Candidate branch.
      Applied before nonlinearities.
    memory_transform_fn: Optional transformation on the memory before gating.
    gate_nonlinearity: Function to use as gate activation. Allows trying
      alternatives to Sigmoid, such as HardSigmoid.
    candidate_nonlinearity: Nonlinearity to apply after candidate branch.
      Allows trying alternatives to traditional Tanh, such as HardTanh.
    dropout_rate_c: Amount of dropout on the transform (c) gate. Dropout
      works best in a GRU when applied exclusively to this branch.
    sigmoid_bias: Constant to add before sigmoid gates. Generally want to
      start off with a positive bias.

  Returns:
    A model representing a GRU cell with specified transforms.
  """
  gate_block = [  # u_t
      candidate_transform(),
      core.AddConstant(constant=sigmoid_bias),
      gate_nonlinearity(),
  ]
  reset_block = [  # r_t
      candidate_transform(),
      core.AddConstant(constant=sigmoid_bias),  # Want bias to start positive.
      gate_nonlinearity(),
  ]
  candidate_block = [
      cb.Dup(),
      reset_block,
      cb.Multiply(),  # Gate S_{t-1} with sigmoid(candidate_transform(S_{t-1}))
      candidate_transform(),  # Final projection + tanh to get C_t
      candidate_nonlinearity(),  # Candidate gate
      # Only apply dropout on the C gate. Paper reports 0.1 as a good default.
      core.Dropout(rate=dropout_rate_c)
  ]
  memory_transform = memory_transform_fn() if memory_transform_fn else []
  return cb.Serial(
      cb.Dup(), cb.Dup(),
      cb.Parallel(memory_transform, gate_block, candidate_block),
      cb.Gate(),
  )
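# The GRU update equations from the docstring, written out for a single step.
# A minimal NumPy sketch, assuming each candidate_transform is a plain affine
# map (in the layer above it is an arbitrary parametrized sublayer):
import numpy as np

def sigmoid(x):
  return 1.0 / (1.0 + np.exp(-x))

def gru_cell_sketch(s_prev, U_gate, B_gate, U_reset, B_reset, U_cand, B_cand,
                    sigmoid_bias=0.5):
  u = sigmoid(U_gate @ s_prev + B_gate + sigmoid_bias)    # update gate
  r = sigmoid(U_reset @ s_prev + B_reset + sigmoid_bias)  # reset gate
  c = np.tanh(U_cand @ (r * s_prev) + B_cand)             # candidate memory
  return u * s_prev + (1.0 - u) * c                       # new state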
def RelativeAttentionLayer(d_feature, context_bias_layer, location_bias_layer,
                           total_kv_pooling, separate_cls, n_heads=1,
                           dropout=0.0, mode='train'):
  """Returns a layer that maps (q, k, v, masks) to (activations, masks).

  When the number of keys is smaller than the number of queries, the layer
  works in O(q^2 * d); otherwise it is O(q * k * d). That is because we need
  to shift relative distances by the current pooling; when we upsample, the
  current pooling is a fraction < 1.

  Visual explanation:

    [01][23][45][67] -> [0][1][2][3][4][5][6][7]

  For token [0] we calculate relative distances as follows:

    * 0 2 4 6

  However, for token [1] we need the relative distances shifted by 1,
  specifically:

    * -1 1 3 5

  So we need to calculate not only the distances that correspond to the
  spacing between the keys, but also the ones in between, because there is
  more than one query token (at different positions, which means different
  relative distances) for a single key token.

  Args:
    d_feature: Depth/dimensionality of feature embedding.
    context_bias_layer: Global context bias from Transformer XL's attention.
        There should be one such layer shared for all relative attention
        layers.
    location_bias_layer: Global location bias from Transformer XL's attention.
        There should be one such layer shared for all relative attention
        layers.
    total_kv_pooling: Accumulated pool size of keys/values used at this layer.
    separate_cls: True/False if we separate the CLS token in calculations.
    n_heads: Number of attention heads.
    dropout: Probabilistic rate for internal dropout applied to attention
        activations (based on query-key pairs) before dotting them with
        values.
    mode: One of `'train'`, `'eval'`, or `'predict'`.
  """
  return cb.Serial(
      cb.Branch(
          PositionalEmbeddings(d_feature, separate_cls, total_kv_pooling),
          cb.Select([0]),
          cb.Select([1])),
      cb.Parallel(
          core.Dense(d_feature),
          core.Dense(d_feature),
          core.Dense(d_feature),
          core.Dense(d_feature),
      ),
      context_bias_layer,
      location_bias_layer,
      RelativeAttention(  # pylint: disable=no-value-for-parameter
          separate_cls=separate_cls,
          n_heads=n_heads,
          dropout=dropout,
          mode=mode),
      core.Dense(d_feature),
  )
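# The docstring's upsampling example, computed explicitly. With an accumulated
# key/value pooling of 2, key j sits at position 2*j in the upsampled query
# space, so the relative distance from query i to key j is 2*j - i. A sketch
# for illustration only (the sign convention here is an assumption):
import numpy as np

n_queries, n_keys, total_kv_pooling = 8, 4, 2
queries = np.arange(n_queries)                 # [0][1][2][3][4][5][6][7]
keys = np.arange(n_keys) * total_kv_pooling    # [01][23][45][67] -> 0, 2, 4, 6
rel_dists = keys[None, :] - queries[:, None]   # [n_queries, n_keys]
print(rel_dists[0])  # [ 0  2  4  6]  -- distances for token [0]
print(rel_dists[1])  # [-1  1  3  5]  -- shifted by 1 for token [1]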
def test_tracer_index(self):
  lyr = cb.Parallel(activation_fns.Tanh(), activation_fns.Tanh())
  a = tracer.Tracer('a')
  b = tracer.Tracer('b')
  d, e = lyr @ (a, b)
  result0 = tracer.IndexExpr(0, tracer.ApplyExpr(lyr, ('a', 'b')))
  result1 = tracer.IndexExpr(1, tracer.ApplyExpr(lyr, ('a', 'b')))
  self.assertEqual(d.expr, result0)
  self.assertEqual(e.expr, result1)
def test_eqns_merge_outputs(self):
  lyr = cb.Parallel(activation_fns.Tanh(), activation_fns.Tanh())
  eqns = [
      tracer.ApplyEqn(lyr, ('a', 'b'), ('var2',)),
      tracer.IndexEqn(0, 'var2', 'var0'),
      tracer.IndexEqn(1, 'var2', 'var1')
  ]
  simple_eqns = tracer.merge_output_tuples(eqns)
  result = [tracer.ApplyEqn(lyr, ('a', 'b'), ('var0', 'var1'))]
  self.assertEqual(simple_eqns, result)
def SRU(n_units, activation=None, mode='train'):
  r"""SRU (Simple Recurrent Unit) layer as in https://arxiv.org/abs/1709.02755.

  As defined in the paper:

  .. math::
     y_t &= W x_t + B \quad \hbox{(include $B$ optionally)} \\
     f_t &= \sigma(Wf x_t + bf) \\
     r_t &= \sigma(Wr x_t + br) \\
     c_t &= f_t \times c_{t-1} + (1 - f_t) \times y_t \\
     h_t &= r_t \times \hbox{activation}(c_t) + (1 - r_t) \times x_t

  We assume the input is of shape [batch, length, depth] and recurrence
  happens on the length dimension. This returns a single layer. It's best
  to use at least 2, they say in the paper, except inside a Transformer.

  Args:
    n_units: output depth of the SRU layer.
    activation: Optional activation function.
    mode: if 'predict' then we save the previous state for one-by-one
      inference.

  Returns:
    The SRU layer.
  """
  sigmoid_activation = activation_fns.Sigmoid()
  return cb.Serial(                                    # x
      cb.Branch(core.Dense(3 * n_units), []),          # r_f_y, x
      cb.Split(n_items=3),                             # r, f, y, x
      cb.Parallel(sigmoid_activation, sigmoid_activation),  # r, f, y, x
      base.Fn('',
              lambda r, f, y: (y * (1.0 - f), f, r),   # y * (1 - f), f, r, x
              n_out=3),
      cb.Parallel([], [], cb.Branch(MakeZeroState(), [])),
      ScanSRUCell(mode=mode),
      cb.Select([0], n_in=2),                          # act(c), r, x
      activation if activation is not None else [],
      base.Fn('FinalSRUGate', lambda c, r, x: c * r + x * (1 - r) * (3**0.5)),
      # Set the name to SRU and don't print sublayers.
      name=f'SRU_{n_units}', sublayers_to_print=[])
def test_input_signatures_parallel(self):
  layer = cb.Parallel(core.Div(divisor=0.5), core.Div(divisor=3.0))
  self.assertIsNone(layer.input_signature)

  layer.input_signature = (ShapeDtype((3, 2)), ShapeDtype((4, 7)))
  self.assertEqual(layer.input_signature,
                   (ShapeDtype((3, 2)), ShapeDtype((4, 7))))

  self.assertLen(layer.sublayers, 2)
  sublayer_0, sublayer_1 = layer.sublayers
  self.assertEqual(sublayer_0.input_signature, ShapeDtype((3, 2)))
  self.assertEqual(sublayer_1.input_signature, ShapeDtype((4, 7)))
def test_input_signatures_parallel(self):
  layer = cb.Parallel(divide_by(0.5), divide_by(3.0))
  self.assertIsNone(layer.input_signature)

  layer._set_input_signature_recursive(
      (ShapeDtype((3, 2)), ShapeDtype((4, 7))))
  self.assertEqual(layer.input_signature,
                   (ShapeDtype((3, 2)), ShapeDtype((4, 7))))

  self.assertLen(layer.sublayers, 2)
  sublayer_0, sublayer_1 = layer.sublayers
  self.assertEqual(sublayer_0.input_signature, ShapeDtype((3, 2)))
  self.assertEqual(sublayer_1.input_signature, ShapeDtype((4, 7)))
def SRU(n_units, activation=None):
  r"""SRU (Simple Recurrent Unit) layer as in https://arxiv.org/abs/1709.02755.

  As defined in the paper:

  .. math::
     y_t &= W x_t + B \quad \hbox{(include $B$ optionally)} \\
     f_t &= \sigma(Wf x_t + bf) \\
     r_t &= \sigma(Wr x_t + br) \\
     c_t &= f_t \times c_{t-1} + (1 - f_t) \times y_t \\
     h_t &= r_t \times \hbox{activation}(c_t) + (1 - r_t) \times x_t

  We assume the input is of shape [batch, length, depth] and recurrence
  happens on the length dimension. This returns a single layer. It's best
  to use at least 2, they say in the paper, except inside a Transformer.

  Args:
    n_units: output depth of the SRU layer.
    activation: Optional activation function.

  Returns:
    The SRU layer.
  """
  sigmoid_activation = activation_fns.Sigmoid()
  return cb.Serial(                                    # x
      cb.Branch(core.Dense(3 * n_units), []),          # r_f_y, x
      cb.Split(n_items=3),                             # r, f, y, x
      cb.Parallel(sigmoid_activation, sigmoid_activation),  # r, f, y, x
      base.Fn('',
              lambda r, f, y: (y * (1.0 - f), f, r),   # y * (1 - f), f, r, x
              n_out=3),
      cb.Parallel([], [], cb.Branch(MakeZeroState(), [])),
      cb.Scan(InnerSRUCell(), axis=1),
      cb.Select([0], n_in=2),                          # act(c), r, x
      activation or [],
      base.Fn('FinalSRUGate', lambda c, r, x: c * r + x * (1 - r) * (3**0.5)))
def test_apply_index_to_eqn(self):
  lyr = cb.Parallel(activation_fns.Tanh(), activation_fns.Tanh())
  a = tracer.Tracer('a')
  b = tracer.Tracer('b')
  c, d = lyr @ (a, b)
  eqns, outputs = tracer.traces_to_eqns((c, d))
  result0 = [
      tracer.ApplyEqn(lyr, ('a', 'b'), ('var2',)),
      tracer.IndexEqn(0, 'var2', 'var0'),
      tracer.IndexEqn(1, 'var2', 'var1')
  ]
  result1 = ('var0', 'var1')
  self.assertEqual(eqns, result0)
  self.assertEqual(outputs, result1)
def SRU(n_units, activation=None, rescale=False, highway_bias=0):
  """SRU (Simple Recurrent Unit) layer as in https://arxiv.org/abs/1709.02755.

  As defined in the paper:

  (1) y_t = W x_t (+ B optionally, which we do)
  (2) f_t = sigmoid(Wf x_t + bf)
  (3) r_t = sigmoid(Wr x_t + br)
  (4) c_t = f_t * c_{t-1} + (1 - f_t) * y_t
  (5) h_t = r_t * activation(c_t) + (1 - r_t) * x_t * alpha

  We assume the input is of shape [batch, length, depth] and recurrence
  happens on the length dimension. This returns a single layer. It's best
  to use at least 2, they say in the paper, except inside a Transformer.

  Args:
    n_units: output depth of the SRU layer.
    activation: Optional activation function.
    rescale: To offset the problem of the gradient vanishing in the h_t as a
        result of light recurrence and highway computation for deeper layers,
        a scaling correction alpha is applied as follows:
        ``(1 + exp(highway_bias) * 2)**0.5``. See
        https://arxiv.org/abs/1709.02755, page 4, section 3.2, Initialization.
    highway_bias: initial bias of highway gates

  Returns:
    The SRU layer.
  """
  # pylint: disable=no-value-for-parameter
  return cb.Serial(                                    # x
      cb.Branch(core.Dense(3 * n_units), []),          # r_f_y, x
      cb.Split(n_items=3),                             # r, f, y, x
      cb.Parallel(core.Sigmoid(), core.Sigmoid()),     # r, f, y, x
      base.Fn(lambda r, f, y: (y * (1.0 - f), f, r)),  # y * (1 - f), f, r, x
      cb.Parallel([], [], cb.Branch(MakeZeroState(), [])),
      cb.Scan(InnerSRUCell(), axis=1),
      cb.Select([0], n_in=2),                          # act(c), r, x
      activation or [],
      base.Fn(lambda c, r, x: c * r + x * (1 - r) * (
          (1 + np.exp(highway_bias) * 2)**0.5 if rescale else 1)))
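# The scaling correction from the rescale argument, evaluated on its own.
# With the default highway_bias=0 it reduces to sqrt(3), which is exactly the
# constant hard-coded as (3**0.5) in the other SRU variants in this section:
import numpy as np

def sru_rescale_alpha(highway_bias=0.0):
  return (1 + np.exp(highway_bias) * 2) ** 0.5

assert np.isclose(sru_rescale_alpha(0.0), 3 ** 0.5)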
def MaskedScalar(metric_layer, mask_id=None, has_weights=False):
  """Metric as scalar compatible with Trax masking."""
  # Stack of (inputs, targets) --> (metric, weight-mask).
  metric_and_mask = [
      cb.Parallel(
          [],
          cb.Dup()  # Duplicate targets
      ),
      cb.Parallel(
          metric_layer,  # Metric: (inputs, targets) --> metric
          WeightMask(mask_id=mask_id)  # pylint: disable=no-value-for-parameter
      )
  ]
  if not has_weights:
    # Take (metric, weight-mask) and return the weighted mean.
    return cb.Serial(metric_and_mask, WeightedMean())  # pylint: disable=no-value-for-parameter
  return cb.Serial(
      metric_and_mask,
      cb.Parallel(
          [],
          cb.Multiply()  # Multiply given weights by mask_id weights
      ),
      WeightedMean()  # pylint: disable=no-value-for-parameter
  )
def AttentionQKV(d_feature, n_heads=1, dropout=0.0, mode='train'):
  """Returns a layer that maps (q, k, v, mask) to (activations, mask).

  Args:
    d_feature: Depth/dimensionality of feature embedding.
    n_heads: Number of attention heads.
    dropout: Probabilistic rate for internal dropout applied to attention
        activations (based on query-key pairs) before dotting them with
        values.
    mode: Either 'train' or 'eval'.
  """
  return cb.Serial(
      cb.Parallel(
          core.Dense(d_feature),
          core.Dense(d_feature),
          core.Dense(d_feature),
      ),
      PureAttention(n_heads=n_heads, dropout=dropout, mode=mode),
      core.Dense(d_feature),
  )
def AttentionQKV(d_feature, n_heads=1, dropout=0.0, mode='train'):
  """Returns a layer that maps (q, k, v, mask) to (activations, mask).

  See `Attention` above for further context/details.

  Args:
    d_feature: Depth/dimensionality of feature embedding.
    n_heads: Number of attention heads.
    dropout: Probabilistic rate for internal dropout applied to attention
        activations (based on query-key pairs) before dotting them with
        values.
    mode: One of `'train'`, `'eval'`, or `'predict'`.
  """
  return cb.Serial(
      cb.Parallel(
          core.Dense(d_feature),
          core.Dense(d_feature),
          core.Dense(d_feature),
      ),
      PureAttention(  # pylint: disable=no-value-for-parameter
          n_heads=n_heads, dropout=dropout, mode=mode),
      core.Dense(d_feature),
  )
def AttentionQKV(d_feature, n_heads=1, dropout=0.0, mode='train',
                 cache_KV_in_predict=False, q_sparsity=None,
                 result_sparsity=None):
  """Returns a layer that maps (q, k, v, mask) to (activations, mask).

  See ``Attention`` above for further context/details.

  Args:
    d_feature: Depth/dimensionality of feature embedding.
    n_heads: Number of attention heads.
    dropout: Probabilistic rate for internal dropout applied to attention
        activations (based on query-key pairs) before dotting them with
        values.
    mode: One of ``'train'``, ``'eval'``, or ``'predict'``.
    cache_KV_in_predict: Whether to cache K/V tensors in predict mode.
    q_sparsity: Sparsity with which to process queries. If None, Dense is
        used. If 'noop' then no processing is used.
    result_sparsity: Sparsity with which to process result of the attention.
        If None, Dense is used. If 'noop' then no processing is used.
  """
  k_processor = core.Dense(d_feature)
  v_processor = core.Dense(d_feature)
  if cache_KV_in_predict and mode == 'predict':
    k_processor = cb.Cache(k_processor)
    v_processor = cb.Cache(v_processor)

  if q_sparsity is None:
    q_processor = core.Dense(d_feature)
  elif q_sparsity == 'noop':
    q_processor = cb.Serial()
  else:
    d_module = d_feature // q_sparsity
    q_processor = cb.Serial(
        sparsity.MultiplicativeSparseDense(q_sparsity, d_feature, d_feature),
        sparsity.LocallyConvDense(q_sparsity, d_module, mode=mode,
                                  kernel_size=3, length_kernel_size=3))

  if result_sparsity is None:
    result_processor = core.Dense(d_feature)
  elif result_sparsity == 'noop':
    result_processor = cb.Serial()
  else:
    d_module = d_feature // result_sparsity
    result_processor = cb.Serial(
        sparsity.MultiplicativeSparseDense(result_sparsity,
                                           d_feature, d_feature),
        sparsity.LocallyConvDense(result_sparsity, d_module, mode=mode,
                                  kernel_size=3, length_kernel_size=3))

  return cb.Serial(
      cb.Parallel(
          q_processor,
          k_processor,
          v_processor,
      ),
      PureAttention(  # pylint: disable=no-value-for-parameter
          n_heads=n_heads, dropout=dropout, mode=mode),
      result_processor)
def some_layer():
  return cb.Parallel(divide_by(2.0), divide_by(5.0))
def test_state_parallel(self):
  model = cb.Parallel(core.Dense(3), core.Dense(5))
  self.assertIsInstance(model.state, tuple)
  self.assertLen(model.state, 2)
def test_parallel_no_ops(self):
  layer = cb.Parallel([], None)
  input_signature = (ShapeDtype((3, 2)), ShapeDtype((4, 7)))
  expected_shape = ((3, 2), (4, 7))
  output_shape = base.check_shape_agreement(layer, input_signature)
  self.assertEqual(output_shape, expected_shape)
def test_parallel_div_div(self):
  layer = cb.Parallel(divide_by(0.5), divide_by(3.0))
  input_signature = (ShapeDtype((3, 2)), ShapeDtype((4, 7)))
  expected_shape = ((3, 2), (4, 7))
  output_shape = base.check_shape_agreement(layer, input_signature)
  self.assertEqual(output_shape, expected_shape)
def test_parallel_dup_dup(self):
  layer = cb.Parallel(cb.Dup(), cb.Dup())
  input_signature = (ShapeDtype((3, 2)), ShapeDtype((4, 7)))
  expected_shape = ((3, 2), (3, 2), (4, 7), (4, 7))
  output_shape = base.check_shape_agreement(layer, input_signature)
  self.assertEqual(output_shape, expected_shape)
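# Why Parallel(Dup(), Dup()) turns shapes ((3, 2), (4, 7)) into four outputs:
# Parallel gives each sublayer its share of the input stack and concatenates
# the results. A plain-Python sketch of that bookkeeping (an illustration of
# the idea, not Trax's internals):
def parallel_sketch(sublayers, stack):
  outputs = []
  for n_in, fn in sublayers:  # (n_in, callable) pairs
    args, stack = stack[:n_in], stack[n_in:]
    outputs.extend(fn(*args))
  return tuple(outputs)

dup = (1, lambda x: (x, x))  # one input, two duplicated outputs
assert parallel_sketch([dup, dup], ('a', 'b')) == ('a', 'a', 'b', 'b')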
def test_weights_parallel(self):
  model = cb.Parallel(core.Dense(3), core.Dense(5))
  self.assertIsInstance(model.weights, tuple)
  self.assertLen(model.weights, 2)
def test_parallel_div_div(self):
  layer = cb.Parallel(core.Div(divisor=0.5), core.Div(divisor=3.0))
  input_shape = ((3, 2), (4, 7))
  expected_shape = ((3, 2), (4, 7))
  output_shape = base.check_shape_agreement(layer, input_shape)
  self.assertEqual(output_shape, expected_shape)
def test_parallel_custom_name(self):
  layer = cb.Parallel(cb.Dup(), cb.Dup())  # pylint: disable=no-value-for-parameter
  self.assertIn('Parallel', str(layer))

  layer = cb.Parallel(cb.Dup(), cb.Dup(), name='DupDup')  # pylint: disable=no-value-for-parameter
  self.assertIn('DupDup', str(layer))
def some_layer():
  return cb.Parallel(core.Div(divisor=2.0), core.Div(divisor=5.0))