def test_branch_named(self):
  input_shape = (2, 3)
  expected_shape = {'a': (2, 3), 'b': (2, 3)}
  output_shape = base.check_shape_agreement(
      combinators.Branch(a=combinators.NoOp(), b=combinators.NoOp()),
      input_shape)
  self.assertEqual(output_shape, expected_shape)
def test_parallel(self):
  input_shape = ((2, 3), (2, 3))
  expected_shape = ((2, 3), (2, 3))
  output_shape = base.check_shape_agreement(
      combinators.Parallel(combinators.NoOp(), combinators.NoOp()),
      input_shape)
  self.assertEqual(output_shape, expected_shape)
def MultiHeadedAttentionQKV(
    feature_depth, num_heads=8, dropout=0.0, mode='train'):
  """Transformer-style multi-headed attention.

  Accepts inputs of the form (q, k, v), mask.

  Args:
    feature_depth: int: depth of embedding
    num_heads: int: number of attention heads
    dropout: float: dropout rate
    mode: str: 'train' or 'eval'

  Returns:
    Multi-headed self-attention result and the mask.
  """
  return combinators.Serial(
      combinators.Parallel(
          core.Dense(feature_depth),
          core.Dense(feature_depth),
          core.Dense(feature_depth),
          combinators.NoOp()
      ),
      PureMultiHeadedAttention(  # pylint: disable=no-value-for-parameter
          feature_depth=feature_depth, num_heads=num_heads,
          dropout=dropout, mode=mode),
      combinators.Parallel(core.Dense(feature_depth), combinators.NoOp())
  )
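# Hedged usage sketch (not part of the original file): shape-checking the
# attention layer in the same style as the surrounding tests. The nested
# ((q, k, v), mask) input-shape convention and the mask shape used here are
# assumptions inferred from the docstring, not confirmed by the source.
def test_multi_headed_attention_qkv(self):
  feature_depth = 16
  input_shape = (((2, 5, feature_depth),   # query shape (assumed)
                  (2, 5, feature_depth),   # key shape (assumed)
                  (2, 5, feature_depth)),  # value shape (assumed)
                 (2, 1, 5, 5))             # hypothetical mask shape
  expected_shape = ((2, 5, feature_depth), (2, 1, 5, 5))
  output_shape = base.check_shape_agreement(
      MultiHeadedAttentionQKV(feature_depth, num_heads=4, mode='eval'),
      input_shape)
  self.assertEqual(output_shape, expected_shape)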
def test_parallel_named(self):
  input_shape = {'a': (2, 3), 'b': (2, 3)}
  expected_shape = {'a': (2, 3), 'b': (2, 3)}
  output_shape = base.check_shape_agreement(
      combinators.Parallel(a=combinators.NoOp()), input_shape)
  self.assertEqual(output_shape, expected_shape)