Example #1
# NOTE: assumes `import tensorflow as tf` (TF 1.x) and an `attention` module in
# scope; each attention function returns a 3-tuple whose last element is the
# attended states.
def attention_matching_layer(seq1,
                             seq1_length,
                             seq2,
                             seq2_length,
                             attn_type='diagonal_bilinear',
                             scaled=True,
                             with_sentinel=False):
    """Encodes seq1 conditioned on seq2, e.g., using word-by-word attention."""
    if attn_type == 'bilinear':
        _, _, attn_states = attention.bilinear_attention(
            seq1, seq2, seq2_length, scaled, with_sentinel)
    elif attn_type == 'dot':
        _, _, attn_states = attention.dot_attention(seq1, seq2, seq2_length,
                                                    scaled, with_sentinel)
    elif attn_type == 'diagonal_bilinear':
        _, _, attn_states = attention.diagonal_bilinear_attention(
            seq1, seq2, seq2_length, scaled, with_sentinel)
    elif attn_type == 'mlp':
        _, _, attn_states = attention.mlp_attention(seq1.get_shape()[-1].value,
                                                    tf.nn.relu, seq1, seq2,
                                                    seq2_length, with_sentinel)
    else:
        raise ValueError("Unknown attention type: %s" % attn_type)

    return attn_states
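
A minimal usage sketch (TF 1.x placeholders; the shapes and dimensions are illustrative assumptions, not taken from the original listing):

# Hypothetical call site: a pair of encoded sequences, e.g. question and passage.
seq1 = tf.placeholder(tf.float32, [None, None, 128])    # [batch, len1, dim]
seq1_length = tf.placeholder(tf.int32, [None])           # [batch]
seq2 = tf.placeholder(tf.float32, [None, None, 128])     # [batch, len2, dim]
seq2_length = tf.placeholder(tf.int32, [None])           # [batch]

# seq1 attends over seq2 word by word; the result is aligned with seq1's time axis.
matched = attention_matching_layer(seq1, seq1_length, seq2, seq2_length,
                                   attn_type='diagonal_bilinear')
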
Example #2
def self_attention(inputs,
                   lengths,
                   attn_type='bilinear',
                   scaled=True,
                   repr_dim=None,
                   activation=None,
                   with_sentinel=False,
                   name='self_attention',
                   reuse=False):
    """Self-attention: attends `inputs` over itself; `repr_dim` and `activation`
    are only used for the 'mlp' attention type."""
    # Pass `reuse` as a keyword argument; positionally it would be taken as
    # `default_name` and silently ignored.
    with tf.variable_scope(name, reuse=reuse):
        if attn_type == 'bilinear':
            attn_states = attention.bilinear_attention(inputs, inputs, lengths,
                                                       scaled,
                                                       with_sentinel)[2]
        elif attn_type == 'dot':
            attn_states = attention.dot_attention(inputs, inputs, lengths,
                                                  scaled, with_sentinel)[2]
        elif attn_type == 'diagonal_bilinear':
            attn_states = attention.diagonal_bilinear_attention(
                inputs, inputs, lengths, scaled, with_sentinel)[2]
        elif attn_type == 'mlp':
            attn_states = attention.mlp_attention(repr_dim, activation, inputs,
                                                  inputs, lengths,
                                                  with_sentinel)[2]
        else:
            raise ValueError("Unknown attention type: %s" % attn_type)

    return attn_states
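
A corresponding sketch for the self-attention variant (shapes are illustrative assumptions, not from the original listing):

# Hypothetical call site: re-encode a sequence against itself.
inputs = tf.placeholder(tf.float32, [None, None, 128])  # [batch, time, dim]
lengths = tf.placeholder(tf.int32, [None])               # [batch]

self_attended = self_attention(inputs, lengths, attn_type='dot')

# The 'mlp' type additionally needs repr_dim and activation, e.g.:
# self_attention(inputs, lengths, attn_type='mlp', repr_dim=128,
#                activation=tf.nn.relu, name='mlp_self_attention')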