  def test_drop(self):
    layer = cb.Drop()
    input_shape = ((3, 2),)
    expected_shape = _EMPTY_STACK
    output_shape = base.check_shape_agreement(layer, input_shape)
    self.assertEqual(output_shape, expected_shape)
    input_shape = ((3, 2),) + _REST_OF_STACK
    expected_shape = _REST_OF_STACK
    output_shape = base.check_shape_agreement(layer, input_shape)
    self.assertEqual(output_shape, expected_shape)
def CausalAttention(d_feature, n_heads=1, dropout=0.0, mode='train'):
  """Transformer-style multi-headed causal attention.

  # TODO(jonni,lukaszkaiser): standardize and improve layer comments.
  Accepts inputs of the form x and constructs (q, k, v) and causal mask from x.

  Args:
    d_feature: int: dimensionality of feature embedding
    n_heads: int: number of attention heads
    dropout: float: dropout rate
    mode: str: 'train' or 'eval'

  Returns:
    Multi-headed self-attention result.
  """
  return [
      cb.Dup(),
      cb.Parallel([], CausalMask(axis=-2)),  # pylint: disable=no-value-for-parameter
      Attention(d_feature, n_heads=n_heads, dropout=dropout, mode=mode),
      cb.Parallel([], cb.Drop()),  # x
  ]
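# A minimal sketch (not from the source) of exercising CausalAttention with the
# same shape-checking helper the tests in this file use. Wrapping the returned
# layer list in cb.Serial and the concrete dimensions below are assumptions.
def check_causal_attention_shape():
  layer = cb.Serial(CausalAttention(d_feature=4, n_heads=2, dropout=0.0, mode='eval'))
  input_shape = (3, 7, 4)  # illustrative (batch, sequence length, d_feature)
  output_shape = base.check_shape_agreement(layer, input_shape)
  # Self-attention over d_feature-sized embeddings should preserve the shape.
  assert output_shape == input_shape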
  def test_drop(self):
    layer = cb.Drop()
    input_shape = (3, 2)
    expected_shape = ()
    output_shape = base.check_shape_agreement(layer, input_shape)
    self.assertEqual(output_shape, expected_shape)