def AttentionQKV(d_feature, n_heads=1, dropout=0.0, mode='train'):
  """Transformer-style multi-headed attention.

  Accepts inputs of the form q, k, v, mask.

  Args:
    d_feature: int: dimensionality of feature embedding
    n_heads: int: number of attention heads
    dropout: float: dropout rate
    mode: str: 'train' or 'eval'

  Returns:
    Multi-headed attention result and the mask.
  """
  return [
      cb.Parallel(
          core.Dense(d_feature),
          core.Dense(d_feature),
          core.Dense(d_feature),
      ),
      PureAttention(  # pylint: disable=no-value-for-parameter
          n_heads=n_heads, dropout=dropout, mode=mode),
      core.Dense(d_feature),
  ]
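
# A minimal, self-contained sketch (plain numpy, not the library's code) of
# the per-head scaled dot-product attention that a core attention layer such
# as PureAttention computes after the Dense projections above; the shapes and
# names here are illustrative assumptions.
import numpy as np

def dot_product_attention(q, k, v, mask=None):
  """q, k, v: arrays of shape (seq_len, d_head); mask: boolean or None."""
  scores = np.dot(q, k.T) / np.sqrt(q.shape[-1])  # (seq_len, seq_len)
  if mask is not None:
    scores = np.where(mask, scores, -1e9)  # block masked positions
  weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
  weights /= weights.sum(axis=-1, keepdims=True)  # softmax over keys
  return np.dot(weights, v)  # (seq_len, d_head)
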
def MultiHeadedAttentionQKV(
    feature_depth, num_heads=8, dropout=0.0, mode='train'):
  """Transformer-style multi-headed attention.

  Accepts inputs of the form (q, k, v), mask.

  Args:
    feature_depth: int: depth of embedding
    num_heads: int: number of attention heads
    dropout: float: dropout rate
    mode: str: 'train' or 'eval'

  Returns:
    Multi-headed attention result and the mask.
  """
  return combinators.Serial(
      combinators.Parallel(
          core.Dense(feature_depth),
          core.Dense(feature_depth),
          core.Dense(feature_depth),
          combinators.NoOp()
      ),
      PureMultiHeadedAttention(  # pylint: disable=no-value-for-parameter
          feature_depth=feature_depth, num_heads=num_heads,
          dropout=dropout, mode=mode),
      combinators.Parallel(core.Dense(feature_depth), combinators.NoOp())
  )
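
# An illustrative sketch (plain Python with hypothetical helpers, not the
# library's combinators) of the routing above: Parallel applies the i-th
# layer to the i-th input, and NoOp passes its input through untouched, so
# the (q, k, v) entries get Dense projections while the mask flows by
# unchanged.
def parallel(*layers):
  return lambda inputs: tuple(f(x) for f, x in zip(layers, inputs))

def noop(x):
  return x

route = parallel(lambda q: ('dense', q),
                 lambda k: ('dense', k),
                 lambda v: ('dense', v),
                 noop)
print(route(('q', 'k', 'v', 'mask')))  # the mask emerges unchanged
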
def test_dense_param_sharing(self):
  model1 = combinators.Serial(core.Dense(32), core.Dense(32))
  layer = core.Dense(32)
  model2 = combinators.Serial(layer, layer)
  rng = backend.random.get_prng(0)
  params1 = model1.initialize((1, 32), onp.float32, rng)
  params2 = model2.initialize((1, 32), onp.float32, rng)
  # The first model has 2 independent kernels of size (32, 32).
  self.assertEqual((32, 32), params1[0][0].shape)
  self.assertEqual((32, 32), params1[1][0].shape)
  # The second model has 1 kernel of size (32, 32) and an empty tuple,
  # because the reused layer's parameters are stored only once.
  self.assertEqual((32, 32), params2[0][0].shape)
  self.assertEqual((), params2[1])
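
# A minimal sketch (hypothetical, not the library's actual implementation) of
# why params2[1] is empty: a Serial container can detect that the same layer
# object appears twice, store its parameters at the first occurrence, and
# return an empty placeholder for every repeat occurrence.
def initialize_serial(layers, init_layer):
  seen, params = set(), []
  for layer in layers:
    if id(layer) in seen:
      params.append(())  # shared layer: parameters already stored above
    else:
      seen.add(id(layer))
      params.append(init_layer(layer))
  return params
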
def ChunkedCausalMultiHeadedAttention(
    feature_depth, num_heads=8, dropout=0.0, chunk_selector=None,
    mode='train'):
  """Transformer-style causal multi-headed attention operating on chunks.

  Accepts inputs that are a list of chunks and applies causal attention.

  Args:
    feature_depth: int: depth of embedding
    num_heads: int: number of attention heads
    dropout: float: dropout rate
    chunk_selector: a function from chunk number to the list of chunks to
      attend to.
    mode: str: 'train' or 'eval'

  Returns:
    Multi-headed self-attention layer.
  """
  prepare_attention_input = combinators.Serial(
      combinators.Branch(
          combinators.Branch(  # q = k = v = first input
              combinators.Copy(), combinators.Copy(), combinators.Copy()),
          CausalMask(axis=-2),  # pylint: disable=no-value-for-parameter
      ),
      combinators.Parallel(
          combinators.Parallel(
              core.Dense(feature_depth),
              core.Dense(feature_depth),
              core.Dense(feature_depth),
          ),
          combinators.Copy()
      )
  )
  return combinators.Serial(
      combinators.Map(prepare_attention_input),
      ChunkedAttentionSelector(selector=chunk_selector),  # pylint: disable=no-value-for-parameter
      combinators.Map(
          PureMultiHeadedAttention(  # pylint: disable=no-value-for-parameter
              feature_depth=feature_depth, num_heads=num_heads,
              dropout=dropout, mode=mode),
          check_shapes=False),
      combinators.Map(combinators.Select(0), check_shapes=False),  # drop masks
      combinators.Map(core.Dense(feature_depth))
  )
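
# One possible chunk_selector (an illustration, not part of the library):
# each chunk attends to itself and to the immediately preceding chunk, which
# keeps attention causal across chunk boundaries while bounding its cost.
def prev_and_current_chunk(chunk_number):
  if chunk_number == 0:
    return [0]
  return [chunk_number - 1, chunk_number]
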
def GRUCell(units):
  """Builds a traditional GRU cell with dense internal transformations.

  Gated Recurrent Unit paper: https://arxiv.org/abs/1412.3555

  Args:
    units: Number of hidden units.

  Returns:
    A Stax model representing a traditional GRU RNN cell.
  """
  return GeneralGRUCell(candidate_transform=lambda: core.Dense(units=units),
                        memory_transform=cb.NoOp,
                        gate_nonlinearity=core.Sigmoid,
                        candidate_nonlinearity=core.Tanh)
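
# For reference, the textbook GRU update from the paper cited above, written
# in plain numpy; this sketches the computation the cell performs once
# GeneralGRUCell wires in the Dense transforms (the weight names here are
# illustrative assumptions, not the layer's actual parameter layout).
import numpy as np

def sigmoid(x):
  return 1.0 / (1.0 + np.exp(-x))

def gru_step(x, h, W_u, U_u, W_r, U_r, W_c, U_c):
  u = sigmoid(x @ W_u + h @ U_u)        # update gate
  r = sigmoid(x @ W_r + h @ U_r)        # reset gate
  c = np.tanh(x @ W_c + (r * h) @ U_c)  # candidate state
  return u * h + (1.0 - u) * c          # new hidden state
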