Example #1
def SRU(n_units, activation=None):
    """SRU (Simple Recurrent Unit) layer as in https://arxiv.org/abs/1709.02755.

  As defined in the paper:
  (1) y_t = W x_t (+ B optionally, which we do)
  (2) f_t = sigmoid(Wf x_t + bf)
  (3) r_t = sigmoid(Wr x_t + br)
  (4) c_t = f_t * c_{t-1} + (1 - f_t) * y_t
  (5) h_t = r_t * activation(c_t) + (1 - r_t) * x_t

  We assume the input is of shape [batch, length, depth] and recurrence
  happens on the length dimension. This returns a single layer. The paper
  recommends stacking at least two such layers, except inside a Transformer.

  Args:
    n_units: output depth of the SRU layer.
    activation: Optional activation function.

  Returns:
    The SRU layer.
  """
    sigmoid_activation = activation_fns.Sigmoid()
    # pylint: disable=no-value-for-parameter
    return cb.Serial(  # x
        cb.Branch(core.Dense(3 * n_units), []),  # r_f_y, x
        cb.Split(n_items=3),  # r, f, y, x
        cb.Parallel(sigmoid_activation, sigmoid_activation),  # r, f, y, x
        base.Fn(lambda r, f, y: (y * (1.0 - f), f, r)),  # y * (1 - f), f, r, x
        cb.Parallel([], [], cb.Branch(MakeZeroState(), [])),
        cb.Scan(InnerSRUCell(), axis=1),
        cb.Select([0], n_in=2),  # act(c), r, x
        activation or [],
        base.Fn(lambda c, r, x: c * r + x * (1 - r)))
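For context, here is a minimal usage sketch of the layer above. It assumes trax is installed and exposes SRU via the top-level trax.layers module (as in current trax releases); it is an illustration added here, not part of the original snippet.

import numpy as np
from trax import layers as tl
from trax import shapes

sru = tl.SRU(n_units=8)                          # one SRU layer
x = np.random.uniform(size=(2, 10, 8)).astype(np.float32)  # [batch, length, depth]
sru.init(shapes.signature(x))                    # initialize weights and state
y = sru(x)                                       # recurrence runs over the length axis
print(y.shape)                                   # (2, 10, 8)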
Example #2
def CausalAttention(d_feature,
                    n_heads=1,
                    d_attention_key=None,
                    d_attention_value=None,
                    attention_type=DotProductCausalAttention,
                    share_qk=False,
                    mode='train'):
    """Transformer-style multi-headed causal attention.

  Args:
    d_feature: int:  dimensionality of feature embedding
    n_heads: int: number of attention heads
    d_attention_key: int: depth of key vector for each attention head
        (default is d_feature // n_heads)
    d_attention_value: int: depth of value vector for each attention head
        (default is d_feature // n_heads)
    attention_type: subclass of BaseCausalAttention: attention class to use
    share_qk: bool, whether to share queries and keys
    mode: str: 'train' or 'eval'

  Returns:
    Multi-headed self-attention result.
  """
    if d_attention_key is None:
        assert d_feature % n_heads == 0
        d_attention_key = d_feature // n_heads
    if d_attention_value is None:
        assert d_feature % n_heads == 0
        d_attention_value = d_feature // n_heads

    if share_qk:
        pre_attention = [
            cb.Dup(),
            cb.Parallel(
                ComputeAttentionHeads(n_heads=n_heads, d_head=d_attention_key),
                ComputeAttentionHeads(n_heads=n_heads,
                                      d_head=d_attention_value),
            ),
            cb.Dup(),
        ]
    else:
        pre_attention = [
            cb.Dup(),
            cb.Dup(),
            cb.Parallel(
                ComputeAttentionHeads(n_heads=n_heads, d_head=d_attention_key),
                ComputeAttentionHeads(n_heads=n_heads, d_head=d_attention_key),
                ComputeAttentionHeads(n_heads=n_heads,
                                      d_head=d_attention_value),
            ),
        ]

    return cb.Serial(pre_attention + [
        attention_type(mode=mode),
        ComputeAttentionOutput(n_heads=n_heads, d_model=d_feature),
    ])
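A minimal usage sketch for the layer above, under the assumption that trax exposes CausalAttention through trax.layers and accepts the defaults shown here; an illustration only, not part of the original snippet.

import numpy as np
from trax import layers as tl
from trax import shapes

attn = tl.CausalAttention(d_feature=16, n_heads=4)
x = np.random.uniform(size=(2, 10, 16)).astype(np.float32)  # [batch, length, d_feature]
attn.init(shapes.signature(x))
y = attn(x)       # each position attends only to itself and earlier positions
print(y.shape)    # (2, 10, 16)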
Example #3
def _WeightedMaskedMean(metric_layer, id_to_mask, has_weights):
  """Computes weighted masked mean of metric_layer(predictions, targets)."""
  multiply_by_weights = cb.Multiply() if has_weights else []
  # Create a layer with 2 or 3 inputs:
  #   - predictions, targets, (weights)
  # that applies the specified metric to a batch and gathers the results into
  # a single scalar.
  return cb.Serial(
      cb.Select([0, 1, 1]),
      cb.Parallel(metric_layer, _ElementMask(id_to_mask=id_to_mask)),
      cb.Parallel([], multiply_by_weights),  # Stack now: metric_values weights
      _WeightedMean()
  )
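To make the stack manipulation concrete, the same weighted masked mean can be written out in plain NumPy. This is only an illustration of what the combinator computes (with a hypothetical helper name), not trax code:

import numpy as np

def weighted_masked_mean(metric_values, targets, id_to_mask, weights=None):
    """Masks out positions whose target equals id_to_mask, then averages."""
    mask = (targets != id_to_mask).astype(np.float32)    # what _ElementMask produces
    if weights is not None:                               # optional per-position weights
        mask = mask * weights                             # the cb.Multiply() step
    return np.sum(metric_values * mask) / np.sum(mask)    # the _WeightedMean() step

metrics = np.array([0.5, 0.2, 0.9, 0.1])
targets = np.array([3, 7, 0, 0])                          # here 0 is the padding id
print(weighted_masked_mean(metrics, targets, id_to_mask=0))  # 0.35, mean of 0.5 and 0.2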
Example #4
    def test_symbolic_decorator3(self):
        add_lyr = cb.Add()
        tanh_lyr = cb.Parallel(activation_fns.Relu(), activation_fns.Tanh())

        @tracer.symbolic
        def make_layer(a, b, c):
            d = add_lyr @ (a, b)
            e = add_lyr @ (d, c)
            f, g = tanh_lyr @ (d, e)
            return f, g

        layer = make_layer()  # pylint: disable=no-value-for-parameter
        a = onp.random.uniform(-10, 10, size=(2, 10))
        b = onp.random.uniform(-10, 10, size=(2, 10))
        c = onp.random.uniform(-10, 10, size=(2, 10))
        input_sd = ShapeDtype((2, 10), onp.int32)
        input_signature = (input_sd, input_sd, input_sd)
        p, s = layer.new_weights_and_state(input_signature)
        res = layer((a, b, c), weights=p, state=s, rng=jax.random.PRNGKey(0))  # pylint: disable=unexpected-keyword-arg,no-value-for-parameter,not-callable
        result0 = onp.array(res[0])
        expected0 = onp.where(a + b > 0, a + b, 0.0)
        onp.testing.assert_allclose(result0, expected0, rtol=1e-5)
        result1 = onp.array(res[1])
        expected1 = onp.tanh(a + b + c)
        onp.testing.assert_allclose(result1, expected1, rtol=1e-5)
Example #5
def AttentionQKV(d_feature, n_heads=1, dropout=0.0, mode='train',
                 cache_KV_in_predict=False, q_sparsity=None,
                 result_sparsity=None):
  """Returns a layer that maps `(AQ, AK, AV, mask)` to `(new-A, mask)`.

  Unlike :py:class:`Attention` above, :py:class:`AttentionQKV` allows the
  incoming activations (`AQ`, `AK`, and `AV`) to come from different sources.
  This is used, for instance, in encoder-decoder attention (Q-related
  activations `AQ` from the decoder, K- and V-related activations -- `AK` and
  `AV` -- from the encoder). Otherwise, see the :py:class:`Attention`
  description for further context/details.

  Args:
    d_feature: Last/innermost dimension of activations in the input to and
        output from this layer.
    n_heads: Number of attention heads. Attention heads effectively split
        activation vectors into ``n_heads`` subvectors, of size
        ``d_feature / n_heads``.
    dropout: Probabilistic rate for attention dropout, which overrides
        (sets to zero) some attention strengths derived from query-key
        matching. As a result, on a given forward pass, some value vectors
        don't contribute to the output, analogous to how regular dropout can
        cause some node activations to be ignored.
    mode: One of ``'train'``, ``'eval'``, or ``'predict'``.
    cache_KV_in_predict: Whether to cache K/V arrays in ``'predict'`` mode.
    q_sparsity: Sparsity with which to process queries. If ``None``,
        :py:class:`Dense` is used; if ``'noop'``, no processing is used.
    result_sparsity: Sparsity with which to process result of the attention.
        If ``None``, :py:class:`Dense` is used; if ``'noop'``, no processing is
        used.
  """
  def _SparsifiableDense(layer_sparsity):
    if layer_sparsity is None:
      return core.Dense(d_feature)
    elif layer_sparsity == 'noop':
      return cb.Serial()  # No-op layer.
    else:
      d_module = d_feature // layer_sparsity
      return cb.Serial(
          sparsity.FactoredDense(layer_sparsity, d_feature, d_feature),
          sparsity.LocallyConvDense(layer_sparsity, d_module, mode=mode,
                                    kernel_size=3, length_kernel_size=3)
      )

  def _CacheableDense():
    if cache_KV_in_predict and mode == 'predict':
      return cb.Cache(core.Dense(d_feature))
    else:
      return core.Dense(d_feature)

  def _PureAttention():
    return PureAttention(n_heads=n_heads, dropout=dropout, mode=mode)

  return cb.Serial(
      cb.Parallel(_SparsifiableDense(q_sparsity),
                  _CacheableDense(),
                  _CacheableDense()),
      _PureAttention(),
      _SparsifiableDense(result_sparsity),
  )
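As a reminder of what PureAttention computes per head between the Q/K/V projections above, here is standard scaled dot-product attention in plain NumPy. This is an illustrative sketch, not the trax implementation; dropout is omitted and masking is simplified:

import numpy as np

def dot_product_attention(q, k, v, mask=None):
    """q, k, v: [length, d_head]; mask: optional boolean [q_len, k_len]."""
    scores = q @ k.T / np.sqrt(q.shape[-1])         # query-key similarities
    if mask is not None:
        scores = np.where(mask, scores, -1e9)       # block masked positions
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)  # softmax over keys
    return weights @ v                              # weighted sum of values

q = k = v = np.random.uniform(size=(5, 8)).astype(np.float32)
print(dot_product_attention(q, k, v).shape)         # (5, 8)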
Example #6
def AttentionQKV(d_feature, n_heads=1, dropout=0.0, mode='train'):
    """Transformer-style multi-headed attention.

  Accepts inputs of the form q, k, v, mask.

  Args:
    d_feature: int:  dimensionality of feature embedding
    n_heads: int: number of attention heads
    dropout: float: dropout rate
    mode: str: 'train' or 'eval'

  Returns:
    Multi-headed self-attention result and the mask.
  """
    return cb.Serial(
        cb.Parallel(
            core.Dense(d_feature),
            core.Dense(d_feature),
            core.Dense(d_feature),
        ),
        PureAttention(  # pylint: disable=no-value-for-parameter
            n_heads=n_heads,
            dropout=dropout,
            mode=mode),
        core.Dense(d_feature),
    )
Example #7
    def test_symbolic_decorator3(self):
        add_lyr = cb.Add()
        tanh_lyr = cb.Parallel(core.Relu(), core.Tanh())

        @tracer.symbolic
        def make_layer(a, b, c):
            d = add_lyr << (a, b)
            e = add_lyr << (d, c)
            f, g = tanh_lyr << (d, e)
            return f, g

        layer = make_layer()  # pylint: disable=no-value-for-parameter
        a = onp.random.uniform(-10, 10, size=(2, 10))
        b = onp.random.uniform(-10, 10, size=(2, 10))
        c = onp.random.uniform(-10, 10, size=(2, 10))
        p, s = layer.new_params_and_state(
            ((2, 10), (2, 10), (2, 10)),
            (onp.float32, onp.float32, onp.float32),
            rng=jax.random.PRNGKey(0))
        res = layer((a, b, c), params=p, state=s, rng=jax.random.PRNGKey(0))  # pylint: disable=unexpected-keyword-arg,no-value-for-parameter,not-callable
        result0 = onp.array(res[0])
        expected0 = onp.where(a + b > 0, a + b, 0.0)
        onp.testing.assert_allclose(result0, expected0, rtol=1e-5)
        result1 = onp.array(res[1])
        expected1 = onp.tanh(a + b + c)
        onp.testing.assert_allclose(result1, expected1, rtol=1e-5)
Example #8
def GeneralGRUCell(candidate_transform,
                   memory_transform_fn=None,
                   gate_nonlinearity=core.Sigmoid,
                   candidate_nonlinearity=core.Tanh,
                   dropout_rate_c=0.1,
                   sigmoid_bias=0.5):
  r"""Parametrized Gated Recurrent Unit (GRU) cell construction.

  GRU update equations:
  $$ \hbox{Update gate: } u_t = \sigma(U' \cdot s_{t-1} + B') $$
  $$ \hbox{Reset gate: } r_t = \sigma(U'' \cdot s_{t-1} + B'') $$
  $$ \hbox{Candidate memory: } c_t = \tanh(U \cdot (r_t \odot s_{t-1}) + B) $$
  $$ \hbox{New state: } s_t = u_t \odot s_{t-1} + (1 - u_t) \odot c_t $$

  See combinators.Gate for details on the gating function.


  Args:
    candidate_transform: Transform to apply inside the Candidate branch. Applied
      before nonlinearities.
    memory_transform_fn: Optional transformation on the memory before gating.
    gate_nonlinearity: Function to use as gate activation. Allows trying
      alternatives to Sigmoid, such as HardSigmoid.
    candidate_nonlinearity: Nonlinearity to apply after candidate branch. Allows
      trying alternatives to traditional Tanh, such as HardTanh.
    dropout_rate_c: Amount of dropout on the transform (c) gate. Dropout works
      best in a GRU when applied exclusively to this branch.
    sigmoid_bias: Constant to add before sigmoid gates. Generally want to start
      off with a positive bias.

  Returns:
    A model representing a GRU cell with specified transforms.
  """
  gate_block = [  # u_t
      candidate_transform(),
      core.AddConstant(constant=sigmoid_bias),
      gate_nonlinearity(),
  ]
  reset_block = [  # r_t
      candidate_transform(),
      core.AddConstant(constant=sigmoid_bias),  # Want bias to start positive.
      gate_nonlinearity(),
  ]
  candidate_block = [
      cb.Dup(),
      reset_block,
      cb.Multiply(),  # Gate S{t-1} with sigmoid(candidate_transform(S{t-1}))
      candidate_transform(),  # Final projection + tanh to get Ct
      candidate_nonlinearity(),  # Candidate gate

      # Only apply dropout on the C gate. Paper reports 0.1 as a good default.
      core.Dropout(rate=dropout_rate_c)
  ]
  memory_transform = memory_transform_fn() if memory_transform_fn else []
  return cb.Serial(
      cb.Dup(), cb.Dup(),
      cb.Parallel(memory_transform, gate_block, candidate_block),
      cb.Gate(),
  )
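The update equations above correspond to the following single step, written in plain NumPy for clarity. It is an illustrative sketch with hypothetical weight names, not the trax cell; note that, like the cell above, the gates are computed from the previous state only:

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gru_step(s_prev, U_u, B_u, U_r, B_r, U_c, B_c):
    """One GRU update on the previous state s_prev, following the docstring."""
    u = sigmoid(s_prev @ U_u + B_u)            # update gate u_t
    r = sigmoid(s_prev @ U_r + B_r)            # reset gate r_t
    c = np.tanh((r * s_prev) @ U_c + B_c)      # candidate memory c_t
    return u * s_prev + (1.0 - u) * c          # new state s_t

d = 4
s = np.random.uniform(size=(1, d)).astype(np.float32)
rand = lambda: np.random.uniform(-0.1, 0.1, size=(d, d)).astype(np.float32)
print(gru_step(s, rand(), 0.5, rand(), 0.5, rand(), 0.0).shape)  # (1, 4)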
Example #9
def RelativeAttentionLayer(d_feature,
                           context_bias_layer,
                           location_bias_layer,
                           total_kv_pooling,
                           separate_cls,
                           n_heads=1,
                           dropout=0.0,
                           mode='train'):
    """Returns a layer that maps (q, k, v, masks) to (activations, masks).

  When the number of keys is smaller than the number of queries, the layer
  works in O(q^2*d); otherwise it is O(q*k*d). That is because we need to shift
  relative distances by the current pooling, and when we upsample this current
  pooling is a fraction < 1.
  Visual explanation:
  [01][23][45][67] -> [0][1][2][3][4][5][6][7]
  For token [0] we calculate relative distances as follows:
  * 0 2 4 6
  However for token [1] we need relative distances changed by 1, specifically:
  * -1 1 3 5
  So we need to calculate not only the distances that correspond to the spacing
  between the keys but also the ones in between, because there is more than one
  query token (at different positions, which means different relative distances)
  per key token.

  Args:
    d_feature: Depth/dimensionality of feature embedding.
    context_bias_layer: Global context bias from Transformer XL's attention.
      There should be one such layer shared for all relative attention layers.
    location_bias_layer: Global location bias from Transformer XL's attention.
      There should be one such layer shared for all relative attention layers.
    total_kv_pooling: Accumulated pool size of keys/values used at this layer.
    separate_cls: True/False, whether we use separate_cls in calculations.
    n_heads: Number of attention heads.
    dropout: Probabilistic rate for internal dropout applied to attention
        activations (based on query-key pairs) before dotting them with values.
    mode: One of `'train'`, `'eval'`, or `'predict'`.
  """

    return cb.Serial(
        cb.Branch(
            PositionalEmbeddings(d_feature, separate_cls, total_kv_pooling),
            cb.Select([0]), cb.Select([1])),
        cb.Parallel(
            core.Dense(d_feature),
            core.Dense(d_feature),
            core.Dense(d_feature),
            core.Dense(d_feature),
        ),
        context_bias_layer,
        location_bias_layer,
        RelativeAttention(  # pylint: disable=no-value-for-parameter
            separate_cls=separate_cls,
            n_heads=n_heads,
            dropout=dropout,
            mode=mode),
        core.Dense(d_feature),
    )
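The relative distances described in the docstring ([0 2 4 6] for query 0, [-1 1 3 5] for query 1, and so on) can be reproduced with a one-line NumPy computation. This is only an illustration of the indexing argument, not the layer's code:

import numpy as np

total_kv_pooling = 2                               # keys sit at every 2nd position
q_positions = np.arange(8)                         # upsampled query positions 0..7
k_positions = np.arange(0, 8, total_kv_pooling)    # key positions 0, 2, 4, 6
rel = k_positions[None, :] - q_positions[:, None]  # one row of distances per query
print(rel[0])   # [0 2 4 6]
print(rel[1])   # [-1  1  3  5]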
Example #10
 def test_tracer_index(self):
     lyr = cb.Parallel(activation_fns.Tanh(), activation_fns.Tanh())
     a = tracer.Tracer('a')
     b = tracer.Tracer('b')
     d, e = lyr @ (a, b)
     result0 = tracer.IndexExpr(0, tracer.ApplyExpr(lyr, ('a', 'b')))
     result1 = tracer.IndexExpr(1, tracer.ApplyExpr(lyr, ('a', 'b')))
     self.assertEqual(d.expr, result0)
     self.assertEqual(e.expr, result1)
Example #11
 def test_eqns_merge_outputs(self):
     lyr = cb.Parallel(activation_fns.Tanh(), activation_fns.Tanh())
     eqns = [
         tracer.ApplyEqn(lyr, ('a', 'b'), ('var2', )),
         tracer.IndexEqn(0, 'var2', 'var0'),
         tracer.IndexEqn(1, 'var2', 'var1')
     ]
     simple_eqns = tracer.merge_output_tuples(eqns)
     result = [tracer.ApplyEqn(lyr, ('a', 'b'), ('var0', 'var1'))]
     self.assertEqual(simple_eqns, result)
Example #12
def SRU(n_units, activation=None, mode='train'):
    r"""SRU (Simple Recurrent Unit) layer as in https://arxiv.org/abs/1709.02755.

  As defined in the paper:

  .. math::
    y_t &= W x_t + B \quad \hbox{(include $B$ optionally)} \\
    f_t &= \sigma(Wf x_t + bf) \\
    r_t &= \sigma(Wr x_t + br) \\
    c_t &= f_t \times c_{t-1} + (1 - f_t) \times y_t \\
    h_t &= r_t \times \hbox{activation}(c_t) + (1 - r_t) \times x_t

  We assume the input is of shape [batch, length, depth] and recurrence
  happens on the length dimension. This returns a single layer. The paper
  recommends stacking at least two such layers, except inside a Transformer.

  Args:
    n_units: output depth of the SRU layer.
    activation: Optional activation function.
    mode: if 'predict' then we save the previous state for one-by-one inference

  Returns:
    The SRU layer.
  """
    sigmoid_activation = activation_fns.Sigmoid()
    return cb.Serial(  # x
        cb.Branch(core.Dense(3 * n_units), []),  # r_f_y, x
        cb.Split(n_items=3),  # r, f, y, x
        cb.Parallel(sigmoid_activation, sigmoid_activation),  # r, f, y, x
        base.Fn(
            '',
            lambda r, f, y: (y * (1.0 - f), f, r),  # y * (1 - f), f, r, x
            n_out=3),
        cb.Parallel([], [], cb.Branch(MakeZeroState(), [])),
        ScanSRUCell(mode=mode),
        cb.Select([0], n_in=2),  # act(c), r, x
        activation if activation is not None else [],
        base.Fn('FinalSRUGate', lambda c, r, x: c * r + x * (1 - r) *
                (3**0.5)),
        # Set the name to SRU and don't print sublayers.
        name=f'SRU_{n_units}',
        sublayers_to_print=[])
Example #13
    def test_input_signatures_parallel(self):
        layer = cb.Parallel(core.Div(divisor=0.5), core.Div(divisor=3.0))
        self.assertIsNone(layer.input_signature)

        layer.input_signature = (ShapeDtype((3, 2)), ShapeDtype((4, 7)))
        self.assertEqual(layer.input_signature, (ShapeDtype(
            (3, 2)), ShapeDtype((4, 7))))
        self.assertLen(layer.sublayers, 2)
        sublayer_0, sublayer_1 = layer.sublayers
        self.assertEqual(sublayer_0.input_signature, ShapeDtype((3, 2)))
        self.assertEqual(sublayer_1.input_signature, ShapeDtype((4, 7)))
Example #14
    def test_input_signatures_parallel(self):
        layer = cb.Parallel(divide_by(0.5), divide_by(3.0))
        self.assertIsNone(layer.input_signature)

        layer._set_input_signature_recursive((ShapeDtype(
            (3, 2)), ShapeDtype((4, 7))))
        self.assertEqual(layer.input_signature, (ShapeDtype(
            (3, 2)), ShapeDtype((4, 7))))
        self.assertLen(layer.sublayers, 2)
        sublayer_0, sublayer_1 = layer.sublayers
        self.assertEqual(sublayer_0.input_signature, ShapeDtype((3, 2)))
        self.assertEqual(sublayer_1.input_signature, ShapeDtype((4, 7)))
Example #15
File: rnn.py Project: zhaoqiuye/trax
def SRU(n_units, activation=None):
    r"""SRU (Simple Recurrent Unit) layer as in https://arxiv.org/abs/1709.02755.

  As defined in the paper:

  .. math::
    y_t &= W x_t + B \quad \hbox{(include $B$ optionally)} \\
    f_t &= \sigma(Wf x_t + bf) \\
    r_t &= \sigma(Wr x_t + br) \\
    c_t &= f_t \times c_{t-1} + (1 - f_t) \times y_t \\
    h_t &= r_t \times \hbox{activation}(c_t) + (1 - r_t) \times x_t

  We assume the input is of shape [batch, length, depth] and recurrence
  happens on the length dimension. This returns a single layer. The paper
  recommends stacking at least two such layers, except inside a Transformer.

  Args:
    n_units: output depth of the SRU layer.
    activation: Optional activation function.

  Returns:
    The SRU layer.
  """
    sigmoid_activation = activation_fns.Sigmoid()
    return cb.Serial(  # x
        cb.Branch(core.Dense(3 * n_units), []),  # r_f_y, x
        cb.Split(n_items=3),  # r, f, y, x
        cb.Parallel(sigmoid_activation, sigmoid_activation),  # r, f, y, x
        base.Fn(
            '',
            lambda r, f, y: (y * (1.0 - f), f, r),  # y * (1 - f), f, r, x
            n_out=3),
        cb.Parallel([], [], cb.Branch(MakeZeroState(), [])),
        cb.Scan(InnerSRUCell(), axis=1),
        cb.Select([0], n_in=2),  # act(c), r, x
        activation or [],
        base.Fn('FinalSRUGate', lambda c, r, x: c * r + x * (1 - r) *
                (3**0.5)))
Example #16
 def test_apply_index_to_eqn(self):
     lyr = cb.Parallel(activation_fns.Tanh(), activation_fns.Tanh())
     a = tracer.Tracer('a')
     b = tracer.Tracer('b')
     c, d = lyr @ (a, b)
     eqns, outputs = tracer.traces_to_eqns((c, d))
     result0 = [
         tracer.ApplyEqn(lyr, ('a', 'b'), ('var2', )),
         tracer.IndexEqn(0, 'var2', 'var0'),
         tracer.IndexEqn(1, 'var2', 'var1')
     ]
     result1 = ('var0', 'var1')
     self.assertEqual(eqns, result0)
     self.assertEqual(outputs, result1)
Example #17
File: rnn.py Project: jackalhan/trax
def SRU(n_units, activation=None, rescale=False, highway_bias=0):
    """SRU (Simple Recurrent Unit) layer as in https://arxiv.org/abs/1709.02755.

  As defined in the paper:
  (1) y_t = W x_t (+ B optionally, which we do)
  (2) f_t = sigmoid(Wf x_t + bf)
  (3) r_t = sigmoid(Wr x_t + br)
  (4) c_t = f_t * c_{t-1} + (1 - f_t) * y_t
  (5) h_t = r_t * activation(c_t) + (1 - r_t) * x_t * alpha

  We assume the input is of shape [batch, length, depth] and recurrence
  happens on the length dimension. This returns a single layer. The paper
  recommends stacking at least two such layers, except inside a Transformer.

  Args:
    n_units: output depth of the SRU layer.
    activation: Optional activation function.
    rescale: If True, applies a scaling correction to offset the vanishing of
      gradients in h_t caused by the light recurrence and highway computation
      in deeper layers. The correction is alpha = (1 + exp(highway_bias) * 2)**0.5;
      see https://arxiv.org/abs/1709.02755, page 4, section 3.2 (Initialization).
    highway_bias: initial bias of the highway gates.

  Returns:
    The SRU layer.
  """
    # pylint: disable=no-value-for-parameter
    return cb.Serial(  # x
        cb.Branch(core.Dense(3 * n_units), []),  # r_f_y, x
        cb.Split(n_items=3),  # r, f, y, x
        cb.Parallel(core.Sigmoid(), core.Sigmoid()),  # r, f, y, x
        base.Fn(lambda r, f, y: (y * (1.0 - f), f, r)),  # y * (1 - f), f, r, x
        cb.Parallel([], [], cb.Branch(MakeZeroState(), [])),
        cb.Scan(InnerSRUCell(), axis=1),
        cb.Select([0], n_in=2),  # act(c), r, x
        activation or [],
        base.Fn(lambda c, r, x: c * r + x * (1 - r) *
                ((1 + np.exp(highway_bias) * 2)**0.5 if rescale else 1)))
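For reference, with the default highway_bias=0 the scaling correction evaluates to sqrt(3), which matches the hard-coded (3**0.5) factor in the FinalSRUGate of Examples #12 and #15. A two-line arithmetic check:

import math

alpha = (1 + math.exp(0) * 2) ** 0.5   # highway_bias = 0
print(alpha)                           # 1.7320508... == math.sqrt(3)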
Example #18
def MaskedScalar(metric_layer, mask_id=None, has_weights=False):
    """Metric as scalar compatible with Trax masking."""
    # Stack of (inputs, targets) --> (metric, weight-mask).
    metric_and_mask = [
        cb.Parallel(
            [],
            cb.Dup()  # Duplicate targets
        ),
        cb.Parallel(
            metric_layer,  # Metric: (inputs, targets) --> metric
            WeightMask(mask_id=mask_id)  # pylint: disable=no-value-for-parameter
        )
    ]
    if not has_weights:
        # Take (metric, weight-mask) and return the weighted mean.
        return cb.Serial(metric_and_mask, WeightedMean())  # pylint: disable=no-value-for-parameter
    return cb.Serial(
        metric_and_mask,
        cb.Parallel(
            [],
            cb.Multiply()  # Multiply given weights by mask_id weights
        ),
        WeightedMean()  # pylint: disable=no-value-for-parameter
    )
Example #19
def AttentionQKV(d_feature, n_heads=1, dropout=0.0, mode='train'):
    """Returns a layer that maps (q, k, v, mask) to (activations, mask).
    Args:
      d_feature: Depth/dimensionality of feature embedding.
      n_heads: Number of attention heads.
      dropout: Probabilistic rate for internal dropout applied to attention
          activations (based on query-key pairs) before dotting them with values.
      mode: Either 'train' or 'eval'.
    """
    return cb.Serial(
        cb.Parallel(
            core.Dense(d_feature),
            core.Dense(d_feature),
            core.Dense(d_feature),
        ),
        PureAttention(n_heads=n_heads, dropout=dropout, mode=mode),
        core.Dense(d_feature),
    )
Example #20
def AttentionQKV(d_feature, n_heads=1, dropout=0.0, mode='train'):
  """Returns a layer that maps (q, k, v, mask) to (activations, mask).

  See `Attention` above for further context/details.

  Args:
    d_feature: Depth/dimensionality of feature embedding.
    n_heads: Number of attention heads.
    dropout: Probabilistic rate for internal dropout applied to attention
        activations (based on query-key pairs) before dotting them with values.
    mode: One of `'train'`, `'eval'`, or `'predict'`.
  """
  return cb.Serial(
      cb.Parallel(
          core.Dense(d_feature),
          core.Dense(d_feature),
          core.Dense(d_feature),
      ),
      PureAttention(  # pylint: disable=no-value-for-parameter
          n_heads=n_heads, dropout=dropout, mode=mode),
      core.Dense(d_feature),
  )
Example #21
def AttentionQKV(d_feature,
                 n_heads=1,
                 dropout=0.0,
                 mode='train',
                 cache_KV_in_predict=False,
                 q_sparsity=None,
                 result_sparsity=None):
    """Returns a layer that maps (q, k, v, mask) to (activations, mask).

  See ``Attention`` above for further context/details.

  Args:
    d_feature: Depth/dimensionality of feature embedding.
    n_heads: Number of attention heads.
    dropout: Probabilistic rate for internal dropout applied to attention
        activations (based on query-key pairs) before dotting them with values.
    mode: One of ``'train'``, ``'eval'``, or ``'predict'``.
    cache_KV_in_predict: Whether to cache K/V tensors in predict mode.
    q_sparsity: Sparsity with which to process queries. If None, Dense is
        used. If 'noop' then no processing is used.
    result_sparsity: Sparsity with which to process result of the attention. If
        None, Dense is used. If 'noop' then no processing is used.
  """
    k_processor = core.Dense(d_feature)
    v_processor = core.Dense(d_feature)
    if cache_KV_in_predict and mode == 'predict':
        k_processor = cb.Cache(k_processor)
        v_processor = cb.Cache(v_processor)

    if q_sparsity is None:
        q_processor = core.Dense(d_feature)
    elif q_sparsity == 'noop':
        q_processor = cb.Serial()
    else:
        d_module = d_feature // q_sparsity
        q_processor = cb.Serial(
            sparsity.MultiplicativeSparseDense(q_sparsity, d_feature,
                                               d_feature),
            sparsity.LocallyConvDense(q_sparsity,
                                      d_module,
                                      mode=mode,
                                      kernel_size=3,
                                      length_kernel_size=3))

    if result_sparsity is None:
        result_processor = core.Dense(d_feature)
    elif result_sparsity == 'noop':
        result_processor = cb.Serial()
    else:
        d_module = d_feature // result_sparsity
        result_processor = cb.Serial(
            sparsity.MultiplicativeSparseDense(result_sparsity, d_feature,
                                               d_feature),
            sparsity.LocallyConvDense(result_sparsity,
                                      d_module,
                                      mode=mode,
                                      kernel_size=3,
                                      length_kernel_size=3))

    return cb.Serial(
        cb.Parallel(
            q_processor,
            k_processor,
            v_processor,
        ),
        PureAttention(  # pylint: disable=no-value-for-parameter
            n_heads=n_heads,
            dropout=dropout,
            mode=mode),
        result_processor)
Example #22
 def some_layer():
     return cb.Parallel(divide_by(2.0), divide_by(5.0))
Example #23
 def test_state_parallel(self):
     model = cb.Parallel(core.Dense(3), core.Dense(5))
     self.assertIsInstance(model.state, tuple)
     self.assertLen(model.state, 2)
Example #24
 def test_parallel_no_ops(self):
     layer = cb.Parallel([], None)
     input_signature = (ShapeDtype((3, 2)), ShapeDtype((4, 7)))
     expected_shape = ((3, 2), (4, 7))
     output_shape = base.check_shape_agreement(layer, input_signature)
     self.assertEqual(output_shape, expected_shape)
Example #25
 def test_parallel_div_div(self):
     layer = cb.Parallel(divide_by(0.5), divide_by(3.0))
     input_signature = (ShapeDtype((3, 2)), ShapeDtype((4, 7)))
     expected_shape = ((3, 2), (4, 7))
     output_shape = base.check_shape_agreement(layer, input_signature)
     self.assertEqual(output_shape, expected_shape)
Example #26
 def test_parallel_dup_dup(self):
     layer = cb.Parallel(cb.Dup(), cb.Dup())
     input_signature = (ShapeDtype((3, 2)), ShapeDtype((4, 7)))
     expected_shape = ((3, 2), (3, 2), (4, 7), (4, 7))
     output_shape = base.check_shape_agreement(layer, input_signature)
     self.assertEqual(output_shape, expected_shape)
Example #27
 def test_weights_parallel(self):
     model = cb.Parallel(core.Dense(3), core.Dense(5))
     self.assertIsInstance(model.weights, tuple)
     self.assertLen(model.weights, 2)
Example #28
 def test_parallel_div_div(self):
     layer = cb.Parallel(core.Div(divisor=0.5), core.Div(divisor=3.0))
     input_shape = ((3, 2), (4, 7))
     expected_shape = ((3, 2), (4, 7))
     output_shape = base.check_shape_agreement(layer, input_shape)
     self.assertEqual(output_shape, expected_shape)
Example #29
  def test_parallel_custom_name(self):
    layer = cb.Parallel(cb.Dup(), cb.Dup())  # pylint: disable=no-value-for-parameter
    self.assertIn('Parallel', str(layer))

    layer = cb.Parallel(cb.Dup(), cb.Dup(), name='DupDup')  # pylint: disable=no-value-for-parameter
    self.assertIn('DupDup', str(layer))
Example #30
 def some_layer():
     return cb.Parallel(core.Div(divisor=2.0), core.Div(divisor=5.0))