Example #1
0
def ChunkedCausalMultiHeadedAttention(feature_depth,
                                      num_heads=8,
                                      dropout=0.0,
                                      chunk_selector=None,
                                      mode='train'):
    """Build causal multi-headed self-attention that operates on chunks.

    The layer accepts a list of chunks. Each chunk is expanded into projected
    queries, keys and values plus a causal mask. ChunkedAttentionSelector then
    chooses, per chunk, which chunks it attends to. Attention is applied to
    every chunk, the masks are dropped, and each chunk receives a final Dense
    projection.

    Args:
        feature_depth: int: depth of the embedding; also the output width of
            the q/k/v projections and of the final projection.
        num_heads: int: number of attention heads.
        dropout: float: dropout rate, forwarded to PureMultiHeadedAttention.
        chunk_selector: a function from chunk number to the list of chunks
            to attend; forwarded to ChunkedAttentionSelector.
        mode: str: 'train' or 'eval', forwarded to PureMultiHeadedAttention.

    Returns:
        A combinators.Serial layer implementing chunked causal multi-headed
        self-attention.
    """
    # Step 1 of the per-chunk preparation: fan the chunk out into three
    # copies (q = k = v) and, in a sibling branch, build its causal mask.
    fan_out_with_mask = combinators.Branch(
        combinators.Branch(  # q = k = v = first input
            combinators.Copy(), combinators.Copy(), combinators.Copy()),
        CausalMask(axis=-2),  # pylint: disable=no-value-for-parameter
    )
    # Step 2: give q, k and v their own Dense projection of the same width;
    # the mask branch goes through combinators.Copy().
    qkv_projection = combinators.Parallel(
        combinators.Parallel(*[core.Dense(feature_depth) for _ in range(3)]),
        combinators.Copy())
    prepare_attention_input = combinators.Serial(fan_out_with_mask,
                                                 qkv_projection)

    stages = [
        combinators.Map(prepare_attention_input),
        ChunkedAttentionSelector(selector=chunk_selector),  # pylint: disable=no-value-for-parameter
        combinators.Map(
            PureMultiHeadedAttention(  # pylint: disable=no-value-for-parameter
                feature_depth=feature_depth, num_heads=num_heads,
                dropout=dropout, mode=mode),
            check_shapes=False),
        # Keep only element 0 of each chunk's outputs, dropping the masks.
        combinators.Map(combinators.Select(0), check_shapes=False),
        combinators.Map(core.Dense(feature_depth)),
    ]
    return combinators.Serial(*stages)
Example #2
0
 def test_select_named(self):
     # Selecting key 'b' from a dict of input shapes should yield exactly
     # the shape stored under 'b'.
     layer = combinators.Select('b')
     shapes = {'a': (2, 3), 'b': (3, 4)}
     self.assertEqual(base.check_shape_agreement(layer, shapes), (3, 4))
Example #3
0
 def test_select(self):
     # Select(1) over a pair of input shapes should produce the second shape.
     shapes = ((2, 3), (3, 4))
     want = (3, 4)
     got = base.check_shape_agreement(combinators.Select(1), shapes)
     self.assertEqual(got, want)
 def test_select_op_not_defined(self):
     # Calling Select with the input shapes as a second positional argument
     # is expected to raise AttributeError.
     # NOTE(review): this test uses the alias `cb`, while the neighbouring
     # tests use `combinators`; confirm `cb` is imported in this module,
     # otherwise a NameError escapes assertRaises and the test errors out.
     shapes = ((3, 2), (4, 7))
     with self.assertRaises(AttributeError):
         cb.Select(1, shapes)