def ChunkedCausalMultiHeadedAttention(
    feature_depth, num_heads=8, dropout=0.0, chunk_selector=None, mode='train'):
  """Transformer-style causal multi-headed attention operating on chunks.

  Accepts inputs that are a list of chunks and applies causal attention.

  Args:
    feature_depth: int: depth of embedding
    num_heads: int: number of attention heads
    dropout: float: dropout rate
    chunk_selector: a function from chunk number to list of chunks to attend.
    mode: str: 'train' or 'eval'

  Returns:
    Multi-headed self-attention layer.
  """
  prepare_attention_input = combinators.Serial(
      combinators.Branch(
          combinators.Branch(  # q = k = v = first input
              combinators.Copy(), combinators.Copy(), combinators.Copy()),
          CausalMask(axis=-2),  # pylint: disable=no-value-for-parameter
      ),
      combinators.Parallel(
          combinators.Parallel(
              core.Dense(feature_depth),
              core.Dense(feature_depth),
              core.Dense(feature_depth),
          ),
          combinators.Copy()
      )
  )
  return combinators.Serial(
      combinators.Map(prepare_attention_input),
      ChunkedAttentionSelector(selector=chunk_selector),  # pylint: disable=no-value-for-parameter
      combinators.Map(PureMultiHeadedAttention(  # pylint: disable=no-value-for-parameter
          feature_depth=feature_depth, num_heads=num_heads,
          dropout=dropout, mode=mode), check_shapes=False),
      combinators.Map(combinators.Select(0), check_shapes=False),  # drop masks
      combinators.Map(core.Dense(feature_depth))
  )
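As a minimal usage sketch (not part of the library itself): the docstring describes `chunk_selector` as a function from a chunk number to the list of chunks to attend to, so a selector that restricts each chunk to itself and its immediate predecessor could look like the following. The name `local_chunk_selector` and the hyperparameter values are illustrative assumptions.

def local_chunk_selector(chunk_number):
  """Hypothetical selector: attend to the current chunk and, if any, the previous one."""
  if chunk_number == 0:
    return [0]
  return [chunk_number - 1, chunk_number]

attention_layer = ChunkedCausalMultiHeadedAttention(
    feature_depth=512,  # illustrative value, not from the source
    num_heads=8,
    dropout=0.1,
    chunk_selector=local_chunk_selector,
    mode='train')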
def test_select_named(self):
  input_shape = {'a': (2, 3), 'b': (3, 4)}
  expected_shape = (3, 4)
  output_shape = base.check_shape_agreement(
      combinators.Select('b'), input_shape)
  self.assertEqual(output_shape, expected_shape)
def test_select(self):
  input_shape = ((2, 3), (3, 4))
  expected_shape = (3, 4)
  output_shape = base.check_shape_agreement(
      combinators.Select(1), input_shape)
  self.assertEqual(output_shape, expected_shape)
def test_select_op_not_defined(self):
  input_shape = ((3, 2), (4, 7))
  with self.assertRaises(AttributeError):
    combinators.Select(1, input_shape)
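Taken together, the two passing tests above suggest the semantics of `combinators.Select`: it picks one element out of a tuple of inputs (by position) or a dict of inputs (by key), and its output shape is the shape of the selected element. A hedged restatement of that reading, reusing only the calls and shapes already shown in the tests:

# Shapes and calls copied from the tests above; this is an illustration,
# not additional library behavior.
assert base.check_shape_agreement(
    combinators.Select(1), ((2, 3), (3, 4))) == (3, 4)          # by position
assert base.check_shape_agreement(
    combinators.Select('b'), {'a': (2, 3), 'b': (3, 4)}) == (3, 4)  # by key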