Example #1
# Shared imports for the examples below.
import tensorflow as tf

from tensor2tensor.layers import common_attention
from tensor2tensor.layers import common_layers
from tensor2tensor.models import transformer


def attn_over_sent_and_lex_2d(x_slices, pad_remover_combined, hparams):
    with tf.variable_scope("self_attention"):
        query_antecedent = common_layers.layer_preprocess(x_slices, hparams)
        y_slices = common_attention.multihead_attention_2d(
            query_antecedent=query_antecedent,
            memory_antecedent=None,
            total_key_depth=hparams.attention_key_channels
            or hparams.hidden_size,
            total_value_depth=hparams.attention_value_channels
            or hparams.hidden_size,
            output_depth=hparams.hidden_size,
            num_heads=hparams.num_heads,
            query_shape=(4, 4),
            memory_flange=(4, 4))
        x_slices = common_layers.layer_postprocess(x_slices, y_slices, hparams)
    with tf.variable_scope("ffn"):
        x0_slices = common_layers.layer_preprocess(x_slices, hparams)
        x0_slices, batch_size, sent_len, lex_cap, hid_dim = reshape_2d(
            x0_slices)
        y_slices = transformer.transformer_ffn_layer(x0_slices, hparams,
                                                     pad_remover_combined)
        y_slices = tf.reshape(y_slices,
                              [batch_size, sent_len, lex_cap, hid_dim])
        x_slices = common_layers.layer_postprocess(x_slices, y_slices, hparams)
    return x_slices
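The helper reshape_2d is not defined in this snippet; below is a minimal sketch of what it is assumed to do, inferred from the call above: flatten the (sent_len, lex_cap) grid into one length axis so transformer_ffn_layer sees a 3-D [batch, length, hidden] tensor. The body is an assumption, not the original implementation.
def reshape_2d(x):
    # x: [batch, sent_len, lex_cap, hid_dim]
    batch_size = tf.shape(x)[0]
    sent_len = tf.shape(x)[1]
    lex_cap = x.shape.as_list()[2]   # assumed statically known
    hid_dim = x.shape.as_list()[3]   # assumed statically known
    # Merge the sentence and lexicon axes: [batch, sent_len * lex_cap, hid_dim].
    x = tf.reshape(x, [batch_size, sent_len * lex_cap, hid_dim])
    return x, batch_size, sent_len, lex_cap, hid_dim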
Example #2
def local_attention_2d(x, hparams, attention_type="local_attention_2d"):
  """Local 2d, self attention layer."""
  # self-attention
  with tf.variable_scope("local_2d_self_att"):
    y = common_attention.multihead_attention_2d(
        x,
        None,
        hparams.attention_key_channels or hparams.hidden_size,
        hparams.attention_value_channels or hparams.hidden_size,
        hparams.hidden_size,
        hparams.num_heads,
        attention_type=attention_type,
        query_shape=hparams.query_shape,
        memory_flange=hparams.memory_flange,
        name="self_attention")
  return y
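A minimal usage sketch for the layer above, assuming tensor2tensor's transformer_base() hparams; query_shape and memory_flange are not part of that base set, so they are added here with illustrative values.
import tensorflow as tf
from tensor2tensor.models import transformer

hparams = transformer.transformer_base()
hparams.add_hparam("query_shape", (8, 8))
hparams.add_hparam("memory_flange", (8, 8))

# x is a 4-D feature map: [batch, height, width, hidden_size].
x = tf.zeros([2, 32, 32, hparams.hidden_size])
y = local_attention_2d(x, hparams)  # y has the same shape as x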
Example #4
def attn_over_sent_and_lex_2d_dec(x, encoder_output,
                                  decoder_self_attention_bias, hparams):
    with tf.variable_scope("self_attention"):
        query_antecedent = common_layers.layer_preprocess(x, hparams)
        y = common_attention.multihead_attention(
            query_antecedent=query_antecedent,
            memory_antecedent=None,
            bias=decoder_self_attention_bias,
            total_key_depth=hparams.attention_key_channels
            or hparams.hidden_size,
            total_value_depth=hparams.attention_value_channels
            or hparams.hidden_size,
            output_depth=hparams.hidden_size,
            num_heads=hparams.num_heads,
            dropout_rate=hparams.attention_dropout,
            attention_type=hparams.self_attention_type,
            max_relative_position=hparams.max_relative_position)
        x = common_layers.layer_postprocess(x, y, hparams)
    if encoder_output is not None:
        with tf.variable_scope("encdec_attention"):
            query_antecedent = common_layers.layer_preprocess(x, hparams)

            batch_size = tf.shape(encoder_output)[0]
            src_len = tf.shape(encoder_output)[1]
            tgt_len = tf.shape(query_antecedent)[1]
            lex_cap = encoder_output.shape.as_list()[2]
            hid_size = encoder_output.shape.as_list()[3]

            # Lift the [batch, tgt_len, hid] decoder query onto the same 4-D grid
            # as the encoder output: add a lex_cap axis (real values in slot 0,
            # zeros elsewhere), then zero-pad both query and memory along the
            # length axis so they share the shape
            # [batch, src_len + tgt_len, lex_cap, hid_size].
            query_antecedent = tf.expand_dims(query_antecedent, 2)
            query_antecedent = tf.pad(
                query_antecedent, [[0, 0], [0, 0], [0, lex_cap - 1], [0, 0]])
            query_pad = tf.zeros([batch_size, src_len, lex_cap, hid_size])
            query_antecedent = tf.concat([query_antecedent, query_pad], 1)

            memory_antecedent = encoder_output
            memory_pad = tf.zeros([batch_size, tgt_len, lex_cap, hid_size])
            memory_antecedent = tf.concat([memory_antecedent, memory_pad], 1)

            tf.logging.info(
                "dimension of decoder input at the enc-dec attention layer: {0}"
                .format(query_antecedent.get_shape()))
            tf.logging.info(
                "dimension of encoder output at the enc-dec attention layer: {0}"
                .format(memory_antecedent.get_shape()))

            y = common_attention.multihead_attention_2d(
                query_antecedent=query_antecedent,
                memory_antecedent=memory_antecedent,
                total_key_depth=hparams.attention_key_channels
                or hparams.hidden_size,
                total_value_depth=hparams.attention_value_channels
                or hparams.hidden_size,
                output_depth=hparams.hidden_size,
                num_heads=hparams.num_heads,
                attention_type="masked_local_attention_2d",
                query_shape=(4, 4),
                memory_flange=(4, 4))

            tf.logging.info("dimension of enc-dec output: {0}".format(
                y.get_shape()))
            # Keep only the first lexicon slot (where the real query was placed)
            # and drop the zero-padded positions beyond the target length.
            y = y[:, :, 0, :]
            y = y[:, :tgt_len, :]

            x = common_layers.layer_postprocess(x, y, hparams)
    with tf.variable_scope("ffn"):
        x0 = common_layers.layer_preprocess(x, hparams)
        y = transformer.transformer_ffn_layer(x0, hparams)
        x = common_layers.layer_postprocess(x, y, hparams)
    return x
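A minimal usage sketch for the decoder block above; shapes and the bias helper are illustrative assumptions. encoder_output is the 4-D [batch, src_len, lex_cap, hidden] tensor produced by an encoder such as the one in Example #1.
import tensorflow as tf
from tensor2tensor.layers import common_attention
from tensor2tensor.models import transformer

hparams = transformer.transformer_base()

batch, src_len, tgt_len, lex_cap = 2, 12, 10, 4
x = tf.zeros([batch, tgt_len, hparams.hidden_size])                        # decoder input
encoder_output = tf.zeros([batch, src_len, lex_cap, hparams.hidden_size])  # encoder output
bias = common_attention.attention_bias_lower_triangle(tgt_len)             # causal self-attention mask
out = attn_over_sent_and_lex_2d_dec(x, encoder_output, bias, hparams)
# out: [batch, tgt_len, hidden_size]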