Example #1
def local_within_block_attention(x,
                                 self_attention_bias,
                                 hparams,
                                 attention_type="local_within_block_mask_right",
                                 q_padding="VALID",
                                 kv_padding="VALID"):
  """Local within block self attention."""
  x_new, x_shape, is_4d = maybe_reshape_4d_to_3d(x)
  with tf.variable_scope("local_within_block"):
    y = common_attention.multihead_attention(
        common_layers.layer_preprocess(x_new, hparams),
        None,
        self_attention_bias,
        hparams.attention_key_channels or hparams.hidden_size,
        hparams.attention_value_channels or hparams.hidden_size,
        hparams.hidden_size,
        hparams.num_heads,
        hparams.attention_dropout,
        attention_type=attention_type,
        block_width=hparams.block_width,
        block_length=hparams.block_length,
        q_padding=q_padding,
        kv_padding=kv_padding,
        q_filter_width=hparams.q_filter_width,
        kv_filter_width=hparams.kv_filter_width,
        name="local_within_block")
    if is_4d:
      y = tf.reshape(y, x_shape)
    return y
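All of the wrappers in these examples call maybe_reshape_4d_to_3d, which is not reproduced here. A minimal sketch of what such a helper could look like, assuming it flattens a [batch, height, width, depth] tensor into [batch, height * width, depth] so the attention ops can treat the input as a 1-d sequence, and reports whether the reshape happened:

# Sketch only; the actual tensor2tensor helper may differ in detail.
def maybe_reshape_4d_to_3d(x):
  """Flatten a 4-d tensor to 3-d if needed and report whether it happened."""
  x_shape = common_layers.shape_list(x)  # static where possible, else dynamic
  is_4d = False
  if len(x_shape) == 4:
    # Merge the two spatial axes into a single length axis.
    x = tf.reshape(x, [x_shape[0], x_shape[1] * x_shape[2], x_shape[3]])
    is_4d = True
  return x, x_shape, is_4d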
Example #2
def full_self_attention(x,
                        self_attention_bias,
                        hparams,
                        q_padding="LEFT",
                        kv_padding="LEFT"):
  """Full self-attention layer."""
  x, x_shape, is_4d = maybe_reshape_4d_to_3d(x)
  if self_attention_bias is not None:
    self_attention_bias = get_self_attention_bias(x)
  with tf.variable_scope("self_att"):
    y = common_attention.multihead_attention(
        x,
        None,
        self_attention_bias,
        hparams.attention_key_channels or hparams.hidden_size,
        hparams.attention_value_channels or hparams.hidden_size,
        hparams.hidden_size,
        hparams.num_heads,
        hparams.attention_dropout,
        q_filter_width=hparams.q_filter_width,
        kv_filter_width=hparams.kv_filter_width,
        q_padding=q_padding,
        kv_padding=kv_padding,
        name="self_att")
    if is_4d:
      y = tf.reshape(y, [x_shape[0], x_shape[1], x_shape[2], x_shape[3]])
      y.set_shape([None, None, None, hparams.hidden_size])
    return y
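full_self_attention rebuilds its bias through get_self_attention_bias, which is also not shown in this section. A plausible sketch, assuming the helper constructs a causal (lower-triangular) bias over the flattened length axis:

# Sketch only; assumes a causal mask is the intended bias.
def get_self_attention_bias(x):
  """Lower-triangular attention bias matching the length axis of x."""
  x_shape = common_layers.shape_list(x)
  # Each position may attend only to itself and earlier positions.
  return common_attention.attention_bias_lower_triangle(x_shape[1])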
Example #3
def dilated_attention_1d(x,
                         hparams,
                         attention_type="masked_dilated_1d",
                         q_padding="VALID",
                         kv_padding="VALID",
                         gap_size=2):
  """Dilated 1d self attention."""
  # self-attention
  x, x_shape, is_4d = maybe_reshape_4d_to_3d(x)
  with tf.variable_scope("masked_dilated_1d"):
    y = common_attention.multihead_attention(
        x,
        None,
        None,
        hparams.attention_key_channels or hparams.hidden_size,
        hparams.attention_value_channels or hparams.hidden_size,
        hparams.hidden_size,
        hparams.num_heads,
        hparams.attention_dropout,
        attention_type=attention_type,
        block_width=hparams.block_width,
        block_length=hparams.block_length,
        q_padding=q_padding,
        kv_padding=kv_padding,
        q_filter_width=hparams.q_filter_width,
        kv_filter_width=hparams.kv_filter_width,
        gap_size=gap_size,
        num_memory_blocks=hparams.num_memory_blocks,
        name="self_attention")
    if is_4d:
      y = tf.reshape(y, x_shape)
      y.set_shape([None, None, None, hparams.hidden_size])
    return y
Example #4
def local_attention_1d(x,
                       hparams,
                       attention_type="local_unmasked",
                       q_padding="VALID",
                       kv_padding="VALID"):
  """Local 1d self attention."""
  # self-attention
  x, x_shape, is_4d = maybe_reshape_4d_to_3d(x)
  with tf.variable_scope("local_1d_self_att"):
    y = common_attention.multihead_attention(
        x,
        None,
        None,
        hparams.attention_key_channels or hparams.hidden_size,
        hparams.attention_value_channels or hparams.hidden_size,
        hparams.hidden_size,
        hparams.num_heads,
        hparams.attention_dropout,
        attention_type=attention_type,
        shared_rel=hparams.shared_rel,
        block_width=hparams.block_width,
        block_length=hparams.block_length,
        q_padding=q_padding,
        kv_padding=kv_padding,
        q_filter_width=hparams.q_filter_width,
        kv_filter_width=hparams.kv_filter_width,
        make_image_summary=False,
        name="self_attention")
    if is_4d:
      y = tf.reshape(y, x_shape)
    return y
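For reference, a hedged usage sketch of local_attention_1d, assuming a TF 1.x environment where tf.contrib is available. The hparams object below is purely illustrative: it defines only the fields these wrappers actually read, and setting attention_key_channels / attention_value_channels to 0 makes them fall back to hidden_size:

# Illustrative call only; the field values are assumptions, not library defaults.
hparams = tf.contrib.training.HParams(
    hidden_size=128,
    num_heads=4,
    attention_key_channels=0,
    attention_value_channels=0,
    attention_dropout=0.1,
    block_length=64,
    block_width=64,
    q_filter_width=1,
    kv_filter_width=1,
    shared_rel=False)

x = tf.zeros([2, 256, hparams.hidden_size])  # [batch, length, hidden]
y = local_attention_1d(x, hparams)           # output has the same shape as x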
Example #5
def local_global_attention(x,
                           self_attention_bias,
                           hparams,
                           q_padding="LEFT",
                           kv_padding="LEFT"):
  """Local and global 1d self attention."""
  with tf.variable_scope("self_local_global_att"):
    [x_global, x_local] = tf.split(x, 2, axis=-1)
    split_hidden_size = int(hparams.hidden_size / 2)
    split_heads = int(hparams.num_heads / 2)
    if self_attention_bias is not None:
      self_attention_bias = get_self_attention_bias(x)
    y_global = common_attention.multihead_attention(
        x_global,
        None,
        self_attention_bias,
        hparams.attention_key_channels or split_hidden_size,
        hparams.attention_value_channels or split_hidden_size,
        split_hidden_size,
        split_heads,
        hparams.attention_dropout,
        q_filter_width=hparams.q_filter_width,
        kv_filter_width=hparams.kv_filter_width,
        q_padding=q_padding,
        kv_padding=kv_padding,
        name="global_self_att")
    y_local = common_attention.multihead_attention(
        x_local,
        None,
        None,
        hparams.attention_key_channels or split_hidden_size,
        hparams.attention_value_channels or split_hidden_size,
        split_hidden_size,
        split_heads,
        hparams.attention_dropout,
        attention_type="local_masked",
        block_length=hparams.block_length,
        block_width=hparams.block_width,
        q_filter_width=hparams.q_filter_width,
        kv_filter_width=hparams.kv_filter_width,
        q_padding=q_padding,
        kv_padding=kv_padding,
        name="local_self_att")
    y = tf.concat([y_global, y_local], axis=-1)
    return y
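Note that local_global_attention splits the channel axis in half and halves the head count, so hparams.hidden_size and hparams.num_heads are both assumed to be even; concatenating the two branch outputs restores the full hidden size. A tiny illustration of that bookkeeping, with arbitrary example values:

# Arbitrary values, chosen only to show the arithmetic.
hidden_size, num_heads = 512, 8
split_hidden_size = hidden_size // 2  # 256 channels per branch
split_heads = num_heads // 2          # 4 heads per branch
# tf.split(x, 2, axis=-1) yields two [batch, length, 256] tensors;
# tf.concat([y_global, y_local], axis=-1) restores [batch, length, 512].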
Example #6
def encdec_attention_1d(x,
                        encoder_output,
                        encoder_decoder_attention_bias,
                        hparams):
  """Local 1d self attention."""
  x, x_shape, is_4d = maybe_reshape_4d_to_3d(x)
  encoder_output, _, _ = maybe_reshape_4d_to_3d(encoder_output)
  with tf.variable_scope("encdec_attention"):
    # Encoder Decoder attention
    y = common_attention.multihead_attention(
        x,
        encoder_output,
        encoder_decoder_attention_bias,
        hparams.attention_key_channels or hparams.hidden_size,
        hparams.attention_value_channels or hparams.hidden_size,
        hparams.hidden_size,
        hparams.num_heads,
        hparams.attention_dropout,
        name="encdec_attention")
  if is_4d:
    y = tf.reshape(y, x_shape)
    y.set_shape([None, None, None, hparams.hidden_size])
  return y
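A hedged sketch of how the encoder_decoder_attention_bias argument is commonly built in tensor2tensor-style code, assuming a tensor of pre-encoder input embeddings (called encoder_input_emb here, a hypothetical name) in which padded positions are all-zero vectors:

# Sketch only; encoder_input_emb is a stand-in for the embedded encoder inputs.
encoder_padding = common_attention.embedding_to_padding(encoder_input_emb)
encoder_decoder_attention_bias = (
    common_attention.attention_bias_ignore_padding(encoder_padding))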
Example #7
def evolved_transformer_encoder(encoder_input,
                                encoder_self_attention_bias,
                                hparams,
                                name="encoder",
                                nonpadding=None,
                                save_weights_to=None,
                                make_image_summary=True,
                                losses=None,
                                attn_bias_for_padding=None):
    """Evolved Transformer encoder. See arxiv.org/abs/1901.11117 for more details.

  Note: Pad remover is not supported.

  Args:
    encoder_input: a Tensor.
    encoder_self_attention_bias: bias Tensor for self-attention (see
      common_attention.attention_bias()).
    hparams: hyperparameters for model.
    name: a string.
    nonpadding: optional Tensor with shape [batch_size, encoder_length]
      indicating what positions are not padding.  This must either be passed in,
      which we do for "packed" datasets, or inferred from
      encoder_self_attention_bias.  The knowledge about padding is used for
      pad_remover(efficiency) and to mask out padding in convolutional layers.
    save_weights_to: an optional dictionary to capture attention weights for
      visualization; the weights tensor will be appended there under a string
      key created from the variable scope (including name).
    make_image_summary: Whether to make an attention image summary.
    losses: Not used.
    attn_bias_for_padding: Padded attention bias in case a unidirectional
      encoder is being used where future attention is masked.

  Returns:
    Tensor encoder output.
  """
    del losses

    hidden_state = encoder_input
    attention_dropout_broadcast_dims = (
        common_layers.comma_separated_string_to_integer_list(
            getattr(hparams, "attention_dropout_broadcast_dims", "")))

    with tf.variable_scope(name):
        if nonpadding is not None:
            padding = 1.0 - nonpadding
        else:
            attention_bias = encoder_self_attention_bias
            if attn_bias_for_padding is not None:
                attention_bias = attn_bias_for_padding
            padding = common_attention.attention_bias_to_padding(
                attention_bias)
            nonpadding = 1.0 - padding

        for layer in range(hparams.num_encoder_layers
                           or hparams.num_hidden_layers):
            with tf.variable_scope("layer_%d" % layer):

                with tf.variable_scope("gated_linear_unit"):

                    residual_state = hidden_state
                    hidden_state = common_layers.layer_preprocess(
                        hidden_state, hparams)

                    values = tf.layers.dense(hidden_state, hparams.hidden_size)
                    gates = tf.layers.dense(hidden_state,
                                            hparams.hidden_size,
                                            activation=tf.nn.sigmoid)
                    hidden_state = values * gates

                    hidden_state = common_layers.layer_postprocess(
                        residual_state, hidden_state, hparams)

                with tf.variable_scope("conv_branches"):

                    residual_state = hidden_state
                    hidden_state = common_layers.layer_preprocess(
                        hidden_state, hparams)
                    # Mask padding from conv layers.
                    mask = tf.tile(tf.expand_dims(nonpadding, 2),
                                   [1, 1, hparams.hidden_size])
                    hidden_state *= mask

                    left_output_dim = int(hparams.hidden_size * 4)
                    left_state = tf.layers.dense(hidden_state,
                                                 left_output_dim,
                                                 activation=tf.nn.relu)
                    left_state = tf.nn.dropout(
                        left_state, 1 - hparams.layer_prepostprocess_dropout)

                    right_output_dim = int(hparams.hidden_size / 2)
                    right_state = tf.layers.conv1d(hidden_state,
                                                   right_output_dim,
                                                   3,
                                                   padding="SAME",
                                                   name="standard_conv_3x1",
                                                   activation=tf.nn.relu)
                    right_state = tf.nn.dropout(
                        right_state, 1 - hparams.layer_prepostprocess_dropout)

                    right_state = tf.pad(
                        right_state, [[0, 0], [0, 0],
                                      [0, left_output_dim - right_output_dim]],
                        constant_values=0)
                    hidden_state = left_state + right_state

                    hidden_state = common_layers.layer_preprocess(
                        hidden_state, hparams)
                    # Mask padding from conv layer.
                    mask = tf.tile(tf.expand_dims(nonpadding, 2),
                                   [1, 1, left_output_dim])
                    hidden_state *= mask

                    separable_conv_9x1 = tf.layers.SeparableConv1D(
                        right_output_dim,
                        9,
                        padding="SAME",
                        name="separable_conv_9x1")
                    hidden_state = separable_conv_9x1.apply(hidden_state)
                    hidden_state = tf.pad(
                        hidden_state,
                        [[0, 0], [0, 0],
                         [0, hparams.hidden_size - right_output_dim]],
                        constant_values=0)

                    hidden_state = common_layers.layer_postprocess(
                        residual_state, hidden_state, hparams)

                with tf.variable_scope("self_attention"):
                    residual_state = hidden_state
                    hidden_state = common_layers.layer_preprocess(
                        hidden_state, hparams)

                    hidden_state = common_attention.multihead_attention(
                        hidden_state,
                        None,
                        encoder_self_attention_bias,
                        hparams.attention_key_channels or hparams.hidden_size,
                        hparams.attention_value_channels
                        or hparams.hidden_size,
                        hparams.hidden_size,
                        hparams.num_heads,
                        hparams.attention_dropout,
                        attention_type=hparams.self_attention_type,
                        max_relative_position=hparams.max_relative_position,
                        heads_share_relative_embedding=(
                            hparams.heads_share_relative_embedding),
                        add_relative_to_values=hparams.add_relative_to_values,
                        save_weights_to=save_weights_to,
                        make_image_summary=make_image_summary,
                        dropout_broadcast_dims=attention_dropout_broadcast_dims,
                        max_length=hparams.get("max_length"),
                        vars_3d=hparams.get("attention_variables_3d"),
                        activation_dtype=hparams.get("activation_dtype",
                                                     "float32"),
                        weight_dtype=hparams.get("weight_dtype", "float32"))

                    hidden_state = common_layers.layer_postprocess(
                        residual_state, hidden_state, hparams)

                with tf.variable_scope("dense_layers"):
                    residual_state = hidden_state
                    hidden_state = common_layers.layer_preprocess(
                        hidden_state, hparams)

                    hidden_state = tf.layers.dense(hidden_state,
                                                   int(hparams.hidden_size *
                                                       4),
                                                   activation=tf.nn.relu)
                    hidden_state = tf.nn.dropout(
                        hidden_state, 1 - hparams.layer_prepostprocess_dropout)

                    hidden_state = tf.layers.dense(hidden_state,
                                                   hparams.hidden_size)
                    hidden_state = common_layers.layer_postprocess(
                        residual_state, hidden_state, hparams)

        # If normalization is done in layer_preprocess, then it should also be done
        # on the output, since the output can grow very large, being the sum of
        # a whole stack of unnormalized layer outputs.
        return common_layers.layer_preprocess(hidden_state, hparams)
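The zero-padding inside conv_branches exists only to make the two branches addable: the dense branch widens to 4 * hidden_size while the convolutional branches narrow to hidden_size / 2, and the narrower tensors are padded with zeros along the channel axis. A worked example of the channel arithmetic, assuming hidden_size = 512:

# Channel-dimension bookkeeping for conv_branches with hidden_size = 512.
hidden_size = 512
left_output_dim = hidden_size * 4    # 2048: ReLU dense branch
right_output_dim = hidden_size // 2  # 256: 3x1 conv branch
# right_state is zero-padded from 256 to 2048 channels before the addition,
# and the 9x1 separable-conv output is zero-padded from 256 back to 512
# channels before the residual layer_postprocess.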
Example #8
def transformer_encoder(encoder_input,
                        encoder_self_attention_bias,
                        hparams,
                        name="encoder",
                        nonpadding=None,
                        save_weights_to=None,
                        make_image_summary=True,
                        losses=None,
                        attn_bias_for_padding=None):
    """A stack of transformer layers.

  Args:
    encoder_input: a Tensor
    encoder_self_attention_bias: bias Tensor for self-attention
       (see common_attention.attention_bias())
    hparams: hyperparameters for model
    name: a string
    nonpadding: optional Tensor with shape [batch_size, encoder_length]
      indicating what positions are not padding.  This must either be
      passed in, which we do for "packed" datasets, or inferred from
      encoder_self_attention_bias.  The knowledge about padding is used
      for pad_remover(efficiency) and to mask out padding in convolutional
      layers.
    save_weights_to: an optional dictionary to capture attention weights
      for visualization; the weights tensor will be appended there under
      a string key created from the variable scope (including name).
    make_image_summary: Whether to make an attention image summary.
    losses: optional list onto which to append extra training losses
    attn_bias_for_padding: Padded attention bias in case a unidirectional
      encoder is being used where future attention is masked.

  Returns:
    y: a Tensor.
  """
    x = encoder_input
    attention_dropout_broadcast_dims = (
        common_layers.comma_separated_string_to_integer_list(
            getattr(hparams, "attention_dropout_broadcast_dims", "")))
    mlperf_log.transformer_print(key=mlperf_log.MODEL_HP_NUM_HIDDEN_LAYERS,
                                 value=hparams.num_encoder_layers
                                 or hparams.num_hidden_layers)
    mlperf_log.transformer_print(key=mlperf_log.MODEL_HP_ATTENTION_DROPOUT,
                                 value=hparams.attention_dropout)
    mlperf_log.transformer_print(key=mlperf_log.MODEL_HP_ATTENTION_DENSE,
                                 value={
                                     "use_bias": "false",
                                     "num_heads": hparams.num_heads,
                                     "hidden_size": hparams.hidden_size
                                 })

    with tf.variable_scope(name):
        if nonpadding is not None:
            padding = 1.0 - nonpadding
        else:
            attention_bias = encoder_self_attention_bias
            if attn_bias_for_padding is not None:
                attention_bias = attn_bias_for_padding
            padding = common_attention.attention_bias_to_padding(
                attention_bias)
            nonpadding = 1.0 - padding
        pad_remover = None
        if hparams.use_pad_remover and not common_layers.is_xla_compiled():
            pad_remover = expert_utils.PadRemover(padding)
        for layer in range(hparams.num_encoder_layers
                           or hparams.num_hidden_layers):
            with tf.variable_scope("layer_%d" % layer):
                with tf.variable_scope("self_attention"):
                    y = common_attention.multihead_attention(
                        common_layers.layer_preprocess(x, hparams),
                        None,
                        encoder_self_attention_bias,
                        hparams.attention_key_channels or hparams.hidden_size,
                        hparams.attention_value_channels
                        or hparams.hidden_size,
                        hparams.hidden_size,
                        hparams.num_heads,
                        hparams.attention_dropout,
                        attention_type=hparams.self_attention_type,
                        max_relative_position=hparams.max_relative_position,
                        heads_share_relative_embedding=(
                            hparams.heads_share_relative_embedding),
                        add_relative_to_values=hparams.add_relative_to_values,
                        save_weights_to=save_weights_to,
                        make_image_summary=make_image_summary,
                        dropout_broadcast_dims=attention_dropout_broadcast_dims,
                        max_length=hparams.get("max_length"),
                        vars_3d=hparams.get("attention_variables_3d"),
                        activation_dtype=hparams.get("activation_dtype",
                                                     "float32"),
                        weight_dtype=hparams.get("weight_dtype", "float32"))
                    x = common_layers.layer_postprocess(x, y, hparams)
                with tf.variable_scope("ffn"):
                    y = transformer_ffn_layer(common_layers.layer_preprocess(
                        x, hparams),
                                              hparams,
                                              pad_remover,
                                              conv_padding="SAME",
                                              nonpadding_mask=nonpadding,
                                              losses=losses)
                    x = common_layers.layer_postprocess(x, y, hparams)
        # if normalization is done in layer_preprocess, then it should also be done
        # on the output, since the output can grow very large, being the sum of
        # a whole stack of unnormalized layer outputs.
        mlperf_log.transformer_print(
            key=mlperf_log.MODEL_HP_NORM,
            value={"hidden_size": hparams.hidden_size})
        return common_layers.layer_preprocess(x, hparams)
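Finally, a hedged sketch of how an encoder like this is typically driven: embed the inputs, derive the self-attention bias from the padding pattern, add a positional signal, and call the encoder. The random inputs stand in for real embeddings, and hparams is assumed to come from a full tensor2tensor hparams set (for example transformer_base()), since this function reads many more fields than the earlier wrappers:

# Illustrative driver; inputs and hparams are assumptions, not a full pipeline.
inputs = tf.random.normal([2, 64, hparams.hidden_size])  # embedded tokens
padding = common_attention.embedding_to_padding(inputs)
encoder_self_attention_bias = (
    common_attention.attention_bias_ignore_padding(padding))
encoder_input = common_attention.add_timing_signal_1d(inputs)
encoder_output = transformer_encoder(
    encoder_input,
    encoder_self_attention_bias,
    hparams,
    nonpadding=1.0 - padding)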