Example #1
# Assumed imports; in the Trax source this helper is nested inside
# AttentionQKV, so d_feature and mode are captured from the enclosing scope.
from trax.layers import combinators as cb
from trax.layers import core
from trax.layers.research import sparsity


def _SparsifiableDense(layer_sparsity):
  if layer_sparsity is None:
    return core.Dense(d_feature)
  elif layer_sparsity == 'noop':
    return cb.Serial()  # No-op layer.
  else:
    d_module = d_feature // layer_sparsity
    return cb.Serial(
        sparsity.FactoredDense(layer_sparsity, d_feature, d_feature),
        sparsity.LocallyConvDense(layer_sparsity, d_module, mode=mode,
                                  kernel_size=3, length_kernel_size=3)
    )
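In the Trax source this helper serves as a small factory for the query and result processors of the attention block. A minimal, hypothetical sketch of that call pattern (q_sparsity and result_sparsity are assumed arguments of the enclosing function, not part of the snippet above):

# Hypothetical call sites inside an AttentionQKV-style builder.
# None -> plain Dense, 'noop' -> pass-through, int -> sparse factored path.
q_processor = _SparsifiableDense(q_sparsity)            # e.g. q_sparsity=4
result_processor = _SparsifiableDense(result_sparsity)  # e.g. result_sparsity='noop'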
Example #2
# Assumed imports for this snippet; PureAttention is defined in the same
# Trax module (trax.layers.attention).
from trax.layers import combinators as cb
from trax.layers import core
from trax.layers.research import sparsity


def AttentionQKV(d_feature,
                 n_heads=1,
                 dropout=0.0,
                 mode='train',
                 cache_KV_in_predict=False,
                 q_sparsity=None,
                 result_sparsity=None):
    """Returns a layer that maps (q, k, v, mask) to (activations, mask).

  See ``Attention`` above for further context/details.

  Args:
    d_feature: Depth/dimensionality of feature embedding.
    n_heads: Number of attention heads.
    dropout: Probababilistic rate for internal dropout applied to attention
        activations (based on query-key pairs) before dotting them with values.
    mode: One of ``'train'``, ``'eval'``, or ``'predict'``.
    cache_KV_in_predict: Whether to cache K/V tensors in predict mode.
    q_sparsity: Sparsity with which to process queries. If None, Dense is
        used. If 'noop' then no processing is used.
    result_sparsity: Sparsity with which to process result of the attention. If
        None, Dense is used. If 'noop' then no processing is used.
  """
    k_processor = core.Dense(d_feature)
    v_processor = core.Dense(d_feature)
    if cache_KV_in_predict and mode == 'predict':
        k_processor = cb.Cache(k_processor)
        v_processor = cb.Cache(v_processor)

    if q_sparsity is None:
        q_processor = core.Dense(d_feature)
    elif q_sparsity == 'noop':
        q_processor = cb.Serial()
    else:
        d_module = d_feature // q_sparsity
        q_processor = cb.Serial(
            sparsity.MultiplicativeSparseDense(q_sparsity, d_feature,
                                               d_feature),
            sparsity.LocallyConvDense(q_sparsity,
                                      d_module,
                                      mode=mode,
                                      kernel_size=3,
                                      length_kernel_size=3))

    if result_sparsity is None:
        result_processor = core.Dense(d_feature)
    elif result_sparsity == 'noop':
        result_processor = cb.Serial()
    else:
        d_module = d_feature // result_sparsity
        result_processor = cb.Serial(
            sparsity.MultiplicativeSparseDense(result_sparsity, d_feature,
                                               d_feature),
            sparsity.LocallyConvDense(result_sparsity,
                                      d_module,
                                      mode=mode,
                                      kernel_size=3,
                                      length_kernel_size=3))

    return cb.Serial(
        cb.Parallel(
            q_processor,
            k_processor,
            v_processor,
        ),
        PureAttention(  # pylint: disable=no-value-for-parameter
            n_heads=n_heads,
            dropout=dropout,
            mode=mode),
        result_processor)
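A minimal usage sketch for the returned layer (the shapes, the boolean mask layout, and the use of trax.shapes.signature are illustrative assumptions, not part of the snippet above):

import numpy as np
from trax import layers as tl
from trax import shapes

batch, seq_len, d_feature = 2, 8, 16
q = k = v = np.ones((batch, seq_len, d_feature), dtype=np.float32)
mask = np.ones((batch, 1, 1, seq_len), dtype=np.bool_)  # attend to all positions

attn = tl.AttentionQKV(d_feature=d_feature, n_heads=4, dropout=0.0, mode='train')
attn.init(shapes.signature((q, k, v, mask)))
activations, out_mask = attn((q, k, v, mask))
print(activations.shape)  # (2, 8, 16)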
Example #3
def AttentionQKV(d_feature,
                 n_heads=1,
                 dropout=0.0,
                 mode='train',
                 cache_KV_in_predict=False,
                 q_sparsity=None,
                 result_sparsity=None):
    """Returns a layer that maps (q, k, v, mask) to (activations, mask).

  See :py:class:`Attention` above for further context/details.

  Args:
    d_feature: Depth/dimensionality of feature embedding.
    n_heads: Number of attention heads.
    dropout: Probababilistic rate for attention dropout, which overrides
        (sets to zero) some attention strengths derived from query-key
        matching. As a result, on a given forward pass, some value vectors
        don't contribute to the output, analogous to how regular dropout can
        cause some node activations to be ignored.
    mode: One of ``'train'``, ``'eval'``, or ``'predict'``.
    cache_KV_in_predict: Whether to cache K/V arrays in ``'predict'`` mode.
    q_sparsity: Sparsity with which to process queries. If ``None``,
        :py:class:`Dense` is used; if ``'noop'``, no processing is used.
    result_sparsity: Sparsity with which to process result of the attention.
        If ``None``, :py:class:`Dense` is used; if ``'noop'``, no processing is
        used.
  """
    k_processor = core.Dense(d_feature)
    v_processor = core.Dense(d_feature)
    if cache_KV_in_predict and mode == 'predict':
        k_processor = cb.Cache(k_processor)
        v_processor = cb.Cache(v_processor)

    if q_sparsity is None:
        q_processor = core.Dense(d_feature)
    elif q_sparsity == 'noop':
        q_processor = cb.Serial()
    else:
        d_module = d_feature // q_sparsity
        q_processor = cb.Serial(
            sparsity.MultiplicativeSparseDense(q_sparsity, d_feature,
                                               d_feature),
            sparsity.LocallyConvDense(q_sparsity,
                                      d_module,
                                      mode=mode,
                                      kernel_size=3,
                                      length_kernel_size=3))

    if result_sparsity is None:
        result_processor = core.Dense(d_feature)
    elif result_sparsity == 'noop':
        result_processor = cb.Serial()
    else:
        d_module = d_feature // result_sparsity
        result_processor = cb.Serial(
            sparsity.MultiplicativeSparseDense(result_sparsity, d_feature,
                                               d_feature),
            sparsity.LocallyConvDense(result_sparsity,
                                      d_module,
                                      mode=mode,
                                      kernel_size=3,
                                      length_kernel_size=3))

    return cb.Serial(
        cb.Parallel(
            q_processor,
            k_processor,
            v_processor,
        ),
        PureAttention(  # pylint: disable=no-value-for-parameter
            n_heads=n_heads,
            dropout=dropout,
            mode=mode),
        result_processor)
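For the sparse paths, the integer sparsity value must divide d_feature, since the code computes d_module = d_feature // sparsity. A hedged configuration sketch (parameter values are illustrative; assumes a Trax version exposing these keyword arguments):

from trax import layers as tl

# Queries go through the MultiplicativeSparseDense + LocallyConvDense path;
# the attention output is passed through unchanged ('noop').
sparse_attn = tl.AttentionQKV(
    d_feature=64,
    n_heads=4,
    dropout=0.1,
    mode='train',
    q_sparsity=4,           # d_module = 64 // 4 = 16
    result_sparsity='noop',
)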