Пример #1
0
def AttentionQKV(d_feature, n_heads=1, dropout=0.0, mode='train'):
    """Transformer-style multi-headed attention.

  Accepts inputs of the form q, k, v, mask.

  Args:
    d_feature: int:  dimensionality of feature embedding
    n_heads: int: number of attention heads
    dropout: float: dropout rate
    mode: str: 'train' or 'eval'

  Returns:
    Multi-headed self-attention result and the mask.
  """
    return [
        cb.Parallel(
            core.Dense(d_feature),
            core.Dense(d_feature),
            core.Dense(d_feature),
        ),
        PureAttention(  # pylint: disable=no-value-for-parameter
            n_heads=n_heads,
            dropout=dropout,
            mode=mode),
        core.Dense(d_feature),
    ]
Пример #2
0
def MultiHeadedAttentionQKV(
    feature_depth, num_heads=8, dropout=0.0, mode='train'):
  """Transformer-style multi-headed attention.

  Accepts inputs of the form (q, k, v), mask.

  Args:
    feature_depth: int:  depth of embedding
    num_heads: int: number of attention heads
    dropout: float: dropout rate
    mode: str: 'train' or 'eval'

  Returns:
    Multi-headed self-attention result and the mask.
  """
  return combinators.Serial(
      combinators.Parallel(
          core.Dense(feature_depth),
          core.Dense(feature_depth),
          core.Dense(feature_depth),
          combinators.NoOp()
      ),
      PureMultiHeadedAttention(  # pylint: disable=no-value-for-parameter
          feature_depth=feature_depth, num_heads=num_heads,
          dropout=dropout, mode=mode),
      combinators.Parallel(core.Dense(feature_depth), combinators.NoOp())
  )
Пример #3
0
 def test_dense_param_sharing(self):
     model1 = combinators.Serial(core.Dense(32), core.Dense(32))
     layer = core.Dense(32)
     model2 = combinators.Serial(layer, layer)
     rng = backend.random.get_prng(0)
     params1 = model1.initialize((1, 32), onp.float32, rng)
     params2 = model2.initialize((1, 32), onp.float32, rng)
     # The first parameters have 2 kernels of size (32, 32).
     self.assertEqual((32, 32), params1[0][0].shape)
     self.assertEqual((32, 32), params1[1][0].shape)
     # The second parameters have 1 kernel of size (32, 32) and an empty dict.
     self.assertEqual((32, 32), params2[0][0].shape)
     self.assertEqual((), params2[1])
Пример #4
0
def ChunkedCausalMultiHeadedAttention(feature_depth,
                                      num_heads=8,
                                      dropout=0.0,
                                      chunk_selector=None,
                                      mode='train'):
    """Transformer-style causal multi-headed attention operating on chunks.

  Accepts inputs that are a list of chunks and applies causal attention.

  Args:
    feature_depth: int:  depth of embedding
    num_heads: int: number of attention heads
    dropout: float: dropout rate
    chunk_selector: a function from chunk number to list of chunks to attend.
    mode: str: 'train' or 'eval'

  Returns:
    Multi-headed self-attention layer.
  """
    prepare_attention_input = combinators.Serial(
        combinators.Branch(
            combinators.Branch(  # q = k = v = first input
                combinators.Copy(), combinators.Copy(), combinators.Copy()),
            CausalMask(axis=-2),  # pylint: disable=no-value-for-parameter
        ),
        combinators.Parallel(
            combinators.Parallel(
                core.Dense(feature_depth),
                core.Dense(feature_depth),
                core.Dense(feature_depth),
            ), combinators.Copy()))
    return combinators.Serial(
        combinators.Map(prepare_attention_input),
        ChunkedAttentionSelector(selector=chunk_selector),  # pylint: disable=no-value-for-parameter
        combinators.Map(
            PureMultiHeadedAttention(  # pylint: disable=no-value-for-parameter
                feature_depth=feature_depth,
                num_heads=num_heads,
                dropout=dropout,
                mode=mode),
            check_shapes=False),
        combinators.Map(combinators.Select(0),
                        check_shapes=False),  # drop masks
        combinators.Map(core.Dense(feature_depth)))
Пример #5
0
def GRUCell(units):
    """Builds a traditional GRU cell with dense internal transformations.

  Gated Recurrent Unit paper: https://arxiv.org/abs/1412.3555


  Args:
    units: Number of hidden units.

  Returns:
    A Stax model representing a traditional GRU RNN cell.
  """
    return GeneralGRUCell(candidate_transform=lambda: core.Dense(units=units),
                          memory_transform=cb.NoOp,
                          gate_nonlinearity=core.Sigmoid,
                          candidate_nonlinearity=core.Tanh)