Example 1
def test_branch_named(self):
    # Branch with named sub-layers copies its single input to every
    # branch and returns a dict of outputs keyed by layer name.
    input_shape = (2, 3)
    expected_shape = {'a': (2, 3), 'b': (2, 3)}
    output_shape = base.check_shape_agreement(
        combinators.Branch(a=combinators.NoOp(), b=combinators.NoOp()),
        input_shape)
    self.assertEqual(output_shape, expected_shape)
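The shape semantics this test exercises can be reproduced without the library. In the sketch below, branch_shapes is a hypothetical stand-in (not part of the library) for a named Branch at the shape level: it feeds the single input shape to every sub-layer and collects the results in a dict keyed by argument name, with NoOp modeled as the identity.

def branch_shapes(input_shape, **layers):
    # Hypothetical stand-in: a named Branch copies its one input to
    # every sub-layer and keys the outputs by argument name.
    return {name: layer(input_shape) for name, layer in layers.items()}

no_op = lambda shape: shape  # identity layer: output shape == input shape

# Mirrors test_branch_named: one (2, 3) input, two identity branches.
assert branch_shapes((2, 3), a=no_op, b=no_op) == {'a': (2, 3), 'b': (2, 3)}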
Example 2
def test_parallel(self):
    # Parallel applies the i-th sub-layer to the i-th input, so two
    # identity layers map two (2, 3) inputs to two (2, 3) outputs.
    input_shape = ((2, 3), (2, 3))
    expected_shape = ((2, 3), (2, 3))
    output_shape = base.check_shape_agreement(
        combinators.Parallel(combinators.NoOp(), combinators.NoOp()),
        input_shape)
    self.assertEqual(output_shape, expected_shape)
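Where Branch copies one input to many branches, Parallel pairs inputs with sub-layers one-to-one. The hypothetical parallel_shapes helper below models that at the shape level: the i-th layer receives the i-th input shape, and the tuple structure is preserved.

def parallel_shapes(input_shapes, *layers):
    # Hypothetical stand-in: the i-th layer gets the i-th input shape,
    # so inputs and layers must pair up one-to-one.
    assert len(input_shapes) == len(layers)
    return tuple(layer(shape) for layer, shape in zip(layers, input_shapes))

no_op = lambda shape: shape  # identity layer: output shape == input shape

# Mirrors test_parallel: two (2, 3) inputs through two identity layers.
assert parallel_shapes(((2, 3), (2, 3)), no_op, no_op) == ((2, 3), (2, 3))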
Example 3
def MultiHeadedAttentionQKV(
    feature_depth, num_heads=8, dropout=0.0, mode='train'):
  """Transformer-style multi-headed attention.

  Accepts inputs of the form (q, k, v), mask.

  Args:
    feature_depth: int: depth of embedding
    num_heads: int: number of attention heads
    dropout: float: dropout rate
    mode: str: 'train' or 'eval'

  Returns:
    Multi-headed attention result and the mask.
  """
  return combinators.Serial(
      combinators.Parallel(
          core.Dense(feature_depth),  # project queries
          core.Dense(feature_depth),  # project keys
          core.Dense(feature_depth),  # project values
          combinators.NoOp()          # pass the mask through unchanged
      ),
      PureMultiHeadedAttention(  # pylint: disable=no-value-for-parameter
          feature_depth=feature_depth, num_heads=num_heads,
          dropout=dropout, mode=mode),
      # Project the attention output; the mask again passes through.
      combinators.Parallel(core.Dense(feature_depth), combinators.NoOp())
  )
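The first Parallel above projects q, k and v with independent Dense layers while the NoOp forwards the mask; the final Parallel projects the attention output the same way. The attention core itself can be illustrated with a simplified NumPy stand-in. The toy_multihead_attention function below is hypothetical and ignores the mask, dropout and learned projections: it splits the feature axis into num_heads heads, runs scaled dot-product attention per head, and merges the heads back.

import numpy as np

def toy_multihead_attention(q, k, v, num_heads):
    """Hypothetical simplified stand-in for PureMultiHeadedAttention:
    no mask, no dropout, no learned projections."""
    batch, seq, depth = q.shape
    head_depth = depth // num_heads

    def split_heads(x):
        # (batch, seq, depth) -> (batch, heads, seq, head_depth)
        return x.reshape(batch, -1, num_heads, head_depth).transpose(0, 2, 1, 3)

    qh, kh, vh = split_heads(q), split_heads(k), split_heads(v)
    # Scaled dot-product attention, computed per head.
    scores = qh @ kh.transpose(0, 1, 3, 2) / np.sqrt(head_depth)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)  # softmax over keys
    out = weights @ vh
    # Merge heads: (batch, heads, seq, head_depth) -> (batch, seq, depth).
    return out.transpose(0, 2, 1, 3).reshape(batch, seq, depth)

x = np.random.randn(2, 5, 16)
assert toy_multihead_attention(x, x, x, num_heads=8).shape == (2, 5, 16)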
Example 4
def test_parallel_named(self):
    # A named Parallel applies each sub-layer to the input with the
    # matching name; inputs without a layer ('b' here) pass through.
    input_shape = {'a': (2, 3), 'b': (2, 3)}
    expected_shape = {'a': (2, 3), 'b': (2, 3)}
    output_shape = base.check_shape_agreement(
        combinators.Parallel(a=combinators.NoOp()), input_shape)
    self.assertEqual(output_shape, expected_shape)
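The pass-through behavior this test relies on can again be modeled with a small hypothetical helper: parallel_named_shapes applies a layer to the input of the same name and leaves unmatched inputs untouched, which is why 'b' keeps its shape despite having no layer.

def parallel_named_shapes(input_shapes, **layers):
    # Hypothetical stand-in: apply each layer to the input with the
    # matching name; inputs without a layer pass through unchanged.
    identity = lambda shape: shape
    return {name: layers.get(name, identity)(shape)
            for name, shape in input_shapes.items()}

no_op = lambda shape: shape

# Mirrors test_parallel_named: only 'a' has a layer; 'b' is untouched.
result = parallel_named_shapes({'a': (2, 3), 'b': (2, 3)}, a=no_op)
assert result == {'a': (2, 3), 'b': (2, 3)}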