def AttentionQKV(d_feature, n_heads=1, dropout=0.0, mode='train',
                 cache_KV_in_predict=False, q_sparsity=None,
                 result_sparsity=None):
  """Returns a layer that maps (q, k, v, mask) to (activations, mask).

  See :py:class:`Attention` above for further context/details.

  Args:
    d_feature: Depth/dimensionality of feature embedding.
    n_heads: Number of attention heads.
    dropout: Probabilistic rate for attention dropout, which overrides
        (sets to zero) some attention strengths derived from query-key
        matching. As a result, on a given forward pass, some value vectors
        don't contribute to the output, analogous to how regular dropout can
        cause some node activations to be ignored.
    mode: One of ``'train'``, ``'eval'``, or ``'predict'``.
    cache_KV_in_predict: Whether to cache K/V arrays in ``'predict'`` mode.
    q_sparsity: Sparsity with which to process queries. If ``None``,
        :py:class:`Dense` is used; if ``'noop'``, no processing is used.
    result_sparsity: Sparsity with which to process the result of the
        attention. If ``None``, :py:class:`Dense` is used; if ``'noop'``, no
        processing is used.
  """
  def _CacheableDense():
    # Dense projection for keys/values; in 'predict' mode the projected
    # K/V arrays can be cached so they are computed only once.
    if cache_KV_in_predict and mode == 'predict':
      return cb.Cache(core.Dense(d_feature))
    else:
      return core.Dense(d_feature)

  k_processor = _CacheableDense()
  v_processor = _CacheableDense()

  # Queries: dense by default, pass-through for 'noop', else a sparse
  # (multiplicative + locally convolutional) replacement for Dense.
  if q_sparsity is None:
    q_processor = core.Dense(d_feature)
  elif q_sparsity == 'noop':
    q_processor = cb.Serial()
  else:
    d_module = d_feature // q_sparsity
    q_processor = cb.Serial(
        sparsity.MultiplicativeSparseDense(q_sparsity, d_feature, d_feature),
        sparsity.LocallyConvDense(q_sparsity, d_module, mode=mode,
                                  kernel_size=3, length_kernel_size=3))

  # Attention output: same choices as for queries.
  if result_sparsity is None:
    result_processor = core.Dense(d_feature)
  elif result_sparsity == 'noop':
    result_processor = cb.Serial()
  else:
    d_module = d_feature // result_sparsity
    result_processor = cb.Serial(
        sparsity.MultiplicativeSparseDense(result_sparsity, d_feature,
                                           d_feature),
        sparsity.LocallyConvDense(result_sparsity, d_module, mode=mode,
                                  kernel_size=3, length_kernel_size=3))

  return cb.Serial(
      cb.Parallel(
          q_processor,
          k_processor,
          v_processor,
      ),
      PureAttention(  # pylint: disable=no-value-for-parameter
          n_heads=n_heads, dropout=dropout, mode=mode),
      result_processor)
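# A minimal usage sketch (illustrative only, not part of this module). It
# assumes the usual Trax imports (`from trax import layers as tl, shapes`)
# plus NumPy as `np`; the (batch, 1, 1, seq_len) padding-mask shape is an
# assumption about the mask convention rather than something defined here:
#
#   attn = tl.AttentionQKV(d_feature=64, n_heads=4, mode='eval')
#   q = k = v = np.zeros((2, 10, 64), dtype=np.float32)  # (batch, seq_len, d_feature)
#   mask = np.ones((2, 1, 1, 10), dtype=np.bool_)        # True = attend to position
#   attn.init(shapes.signature((q, k, v, mask)))
#   activations, out_mask = attn((q, k, v, mask))
#
#   # With sparse query/result processing; the sparsity value is assumed to
#   # divide d_feature, since d_module = d_feature // q_sparsity above.
#   sparse_attn = tl.AttentionQKV(d_feature=64, n_heads=4, mode='train',
#                                 q_sparsity=4, result_sparsity=4)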