Exemple #1
0
def _FunnelRelativeDecoderBlock(d_model, d_ff, n_heads, dropout,
                                dropout_shared_axes, mode, ff_activation,
                                total_pooling, shorten_factor, resampler_fn):
  """Returns the layers of a funnel decoder block with relative attention.

  The block consumes a single activation tensor. It layer-normalizes the
  input, places a resampled (and re-normalized) copy on the stack next to the
  normalized input, runs relative attention inside a residual connection, and
  ends with a residual feed-forward sub-block.

  Args:
    d_model: Final dimension of tensors at most points in the model, including
      the initial embedding output. Also handed to `resampler_fn`.
    d_ff: Size of the special dense layer in the feed-forward part of the
      block.
    n_heads: Number of attention heads.
    dropout: Stochastic rate (probability) for dropping an activation value
      when applying dropout within the block.
    dropout_shared_axes: Tensor axes on which to share a dropout mask.
      Sharing along batch and sequence axes (`dropout_shared_axes=(0,1)`)
      saves memory and applies consistent masks to activation vectors at
      different sequence positions.
    mode: If `'train'`, the block will include dropout; else, it will pass
      all values through unaltered.
    ff_activation: Type of activation function used by the feed-forward
      sub-block; must be an activation-type subclass of `Layer`.
    total_pooling: Forwarded unchanged to `RelativeAttentionLMLayer`.
      NOTE(review): presumably the combined pooling factor of the funnel
      blocks applied so far -- confirm against the attention layer.
    shorten_factor: By how much to shorten/upsample at this funnel block;
      forwarded to `resampler_fn`.
    resampler_fn: Function that builds the funnel upsampling/downsampling
      layer; called as `resampler_fn(shorten_factor, d_model)` and must
      return an activation-type subclass of `Layer`.

  Returns:
    A list of layers that maps an activation tensor to an activation tensor.
  """
  # Build the parameterized sub-layers first, in the order the block has
  # always created them.
  resample_layer = resampler_fn(shorten_factor, d_model)
  rel_attention = RelativeAttentionLMLayer(
      d_model, total_pooling, n_heads=n_heads, dropout=dropout, mode=mode)
  ff_block = _FeedForwardBlock(
      d_model, d_ff, dropout, dropout_shared_axes, mode, ff_activation)
  attention_dropout = tl.Dropout(
      rate=dropout, shared_axes=dropout_shared_axes, mode=mode)

  # Stage 1: normalize the incoming activations.              stack: h
  input_norm = tl.LayerNorm()

  # Stage 2: keep h and put its resampled, normalized copy on top of it.
  #                                                           stack: h', h
  resample_branch = tl.Branch(
      tl.Serial(resample_layer, tl.LayerNorm()),
      None,
  )

  # Stage 3: residual attention; Select expands the stack to h', h, h
  # before it reaches the attention layer.
  attention_sublayer = tl.Residual(
      tl.Select([0, 1, 1]),
      rel_attention,
      attention_dropout,
  )

  # Stage 4: residual feed-forward.
  feed_forward_sublayer = tl.Residual(ff_block)

  return [
      input_norm,
      resample_branch,
      attention_sublayer,
      feed_forward_sublayer,
  ]
Exemple #2
0
def _RelativeDecoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes,
                          mode, ff_activation, context_bias_layer,
                          location_bias_layer, total_pooling):
  """Returns the layers of a decoder block built on relative attention.

  The block is two residual sub-blocks in sequence: a layer-normalized
  relative-attention step followed by dropout, and then a feed-forward step.

  NOTE(review): the previous documentation described this as an encoder block
  operating on an (activations, mask) pair and returning
  (activations, att_vecs, mask). The layers assembled here only duplicate the
  top stack element before attention and never reference a mask explicitly;
  confirm the real stack contract against `RelativeAttentionLMLayer`.

  Args:
    d_model: Final dimension of tensors at most points in the model, including
        the initial embedding output.
    d_ff: Size of the special dense layer in the feed-forward part of the
        block.
    n_heads: Number of attention heads.
    dropout: Stochastic rate (probability) for dropping an activation value
        when applying dropout within the block.
    dropout_shared_axes: Tensor axes on which to share a dropout mask.
        Sharing along batch and sequence axes (`dropout_shared_axes=(0,1)`)
        saves memory and applies consistent masks to activation vectors at
        different sequence positions.
    mode: If `'train'`, the block will include dropout; else, it will pass
        all values through unaltered.
    ff_activation: Type of activation function used by the feed-forward
        sub-block; must be an activation-type subclass of `Layer`.
    context_bias_layer: Global context bias from Transformer XL's attention.
    location_bias_layer: Global location bias from Transformer XL's attention.
    total_pooling: The combined pool size of previously used funnel blocks.

  Returns:
    A two-element list of layers: the residual attention sub-block followed
    by the residual feed-forward sub-block.
  """
  rel_attention = RelativeAttentionLMLayer(
      d_model,
      context_bias_layer,
      location_bias_layer,
      total_pooling,
      n_heads=n_heads,
      dropout=dropout,
      mode=mode)
  ff_block = _FeedForwardBlock(
      d_model, d_ff, dropout, dropout_shared_axes, mode, ff_activation)
  attention_dropout = tl.Dropout(
      rate=dropout, shared_axes=dropout_shared_axes, mode=mode)

  # vecs -> vecs. Select([0, 0, 0]) hands the attention layer three copies of
  # the normalized input (presumably query/key/value -- confirm).
  attention_sublayer = tl.Residual(
      tl.LayerNorm(),
      tl.Select([0, 0, 0]),
      rel_attention,
      attention_dropout,
  )

  # vecs -> vecs.
  feed_forward_sublayer = tl.Residual(ff_block)

  return [attention_sublayer, feed_forward_sublayer]
Exemple #3
0
def AttentionResampling(shorten_factor, d_model, is_upsampling, d_ff, n_heads,
                        dropout, dropout_shared_axes, mode, ff_activation,
                        context_bias_layer, location_bias_layer, total_pooling,
                        resampling_fn):
    """Returns the layers of an attention-based resampling block.

    The block layer-normalizes its input, places a resampled (and
    re-normalized) copy on the stack next to it, optionally adds a further
    stack element to that copy when upsampling, and then applies a residual
    relative-attention step and a residual feed-forward step.

    Args:
      shorten_factor: Resampling factor; forwarded to `resampling_fn`.
      d_model: Model depth; forwarded to the attention layer, the
        feed-forward block and `resampling_fn`.
      is_upsampling: If truthy, an extra step adds the third stack element to
        the resampled activations before attention; otherwise that step is
        omitted. NOTE(review): the third element is presumably the
        activations saved before shortening -- confirm with the caller.
      d_ff: Forwarded to `FeedForwardBlock`.
      n_heads: Number of attention heads; forwarded to the attention layer.
      dropout: Dropout rate used by the attention layer, the feed-forward
        block and the standalone dropout layers.
      dropout_shared_axes: Axes on which dropout masks are shared; forwarded
        to the feed-forward block and the standalone dropout layers.
      mode: Forwarded to every sub-layer that accepts a mode.
      ff_activation: Forwarded to `FeedForwardBlock`.
      context_bias_layer: Forwarded to `RelativeAttentionLMLayer`.
      location_bias_layer: Forwarded to `RelativeAttentionLMLayer`.
      total_pooling: Forwarded to `RelativeAttentionLMLayer`.
      resampling_fn: Callable building the resampling layer; invoked as
        `resampling_fn(shorten_factor, d_model, mode=mode)`.

    Returns:
      A list of layers (the upsampling entry is an empty list when
      `is_upsampling` is falsy).
    """
    rel_attention = RelativeAttentionLMLayer(
        d_model,
        context_bias_layer,
        location_bias_layer,
        total_pooling,
        n_heads=n_heads,
        dropout=dropout,
        mode=mode,
    )
    ff_block = FeedForwardBlock(
        d_model, d_ff, dropout, dropout_shared_axes, mode, ff_activation)
    resampler = resampling_fn(shorten_factor, d_model, mode=mode)

    def _make_dropout():
        # Each call yields a fresh layer so the two residual sub-blocks do
        # not share a single dropout instance.
        return core.Dropout(
            rate=dropout, shared_axes=dropout_shared_axes, mode=mode)

    # Stage 1: normalize the incoming activations.            stack: h
    input_norm = LayerNorm()

    # Stage 2: resampled + normalized copy on top of h.       stack: h', h
    resample_branch = cb.Branch(
        cb.Serial(resampler, LayerNorm()),
        None,
    )

    # Stage 3 (upsampling only): reorder the stack so the third element sits
    # right under h', then add the two together.
    if is_upsampling:
        upsampling_merge = cb.Serial(
            cb.Select([0, 2, 1, 2]),
            cb.Add(),
        )
    else:
        upsampling_merge = []

    # Stage 4: residual attention over h', h, h.
    attention_sublayer = cb.Residual(
        cb.Select([0, 1, 1]),
        rel_attention,
        _make_dropout(),
    )

    # Stage 5: residual, pre-normalized feed-forward.
    feed_forward_sublayer = cb.Residual(
        LayerNorm(),
        ff_block,
        _make_dropout(),
    )

    return [
        input_norm,
        resample_branch,
        upsampling_merge,
        attention_sublayer,
        feed_forward_sublayer,
    ]