Example #1
def vanilla_transformer_layer(x,
                              config,
                              is_training=True,
                              attn_bias=None,
                              layer_idx=0):  # pylint: disable=unused-argument
    """transformer layer: attention + ffn."""
    # Attention
    with tf.variable_scope('attn'):
        shortcut, x = ops.preprocess(x, config)
        x = attention.multihead_attention(x,
                                          x,
                                          x,
                                          config.model_size,
                                          config.num_heads,
                                          is_training=is_training,
                                          dropatt=config.dropatt,
                                          attn_bias=attn_bias,
                                          bias=config.dense_use_bias)
        x = ops.postprocess(shortcut, x, config, is_training)

    # FFN
    with tf.variable_scope('ffn'):
        shortcut, x = ops.preprocess(x, config)
        x = ops.ffn(x, is_training, config.dropout)
        x = ops.postprocess(shortcut, x, config, is_training)

    return x
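Both sub-blocks follow the same residual pattern: ops.preprocess returns the residual shortcut together with the tensor handed to the sub-layer, and ops.postprocess combines the sub-layer output with that shortcut. Below is a minimal self-contained sketch of one plausible variant (pre-layer-norm, dropout-then-add); the actual behavior of ops.preprocess/ops.postprocess is driven by config and may differ, and the _sketch names are made up for illustration.

import tensorflow.compat.v1 as tf

def preprocess_sketch(x, eps=1e-6):
    # Hypothetical pre-norm variant: keep the raw input as the residual
    # shortcut and hand a layer-normalized copy (no learned scale/offset)
    # to the sub-layer.
    shortcut = x
    mean = tf.reduce_mean(x, axis=-1, keepdims=True)
    var = tf.reduce_mean(tf.square(x - mean), axis=-1, keepdims=True)
    return shortcut, (x - mean) * tf.rsqrt(var + eps)

def postprocess_sketch(shortcut, x, dropout_rate, is_training):
    # Hypothetical post-processing: dropout on the sub-layer output,
    # then the residual connection.
    if is_training:
        x = tf.nn.dropout(x, rate=dropout_rate)
    return shortcut + x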
Example #2
def transformer_approx_att_layer(x,
                                 config,
                                 is_training=True,
                                 attn_bias=None,
                                 attn_impl=None,
                                 layer_idx=0):
    """transformer layer: approximated attention + ffn."""
    # Attention
    with tf.variable_scope('attn'):
        shortcut, x = ops.preprocess(x, config)
        x = attn_impl(x, config, is_training=is_training)
        x = ops.postprocess(shortcut, x, config, is_training)

    # FFN
    with tf.variable_scope('ffn'):
        shortcut, x = ops.preprocess(x, config)
        x = ops.ffn(x, is_training, config.dropout)
        x = ops.postprocess(shortcut, x, config, is_training)

    return x
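The only difference from Example #1 is that the exact multi-head attention is replaced by a pluggable attn_impl callable with the signature (x, config, is_training=...), for instance the axial_rowmajor kernel from Example #3; such a kernel handles causal masking internally rather than through attn_bias. A hypothetical wiring sketch follows, assuming the functions above are in scope (stack and the scope names are made up for illustration):

import functools

def stack(x, config, num_layers, is_training=True):
    # Hypothetical driver, not part of the original code: fix the
    # approximate-attention kernel once and reuse the same layer
    # function for every layer of the stack.
    attn_impl = functools.partial(axial_rowmajor, causal=True)
    for layer_idx in range(num_layers):
        with tf.variable_scope('layer_%d' % layer_idx):
            x = transformer_approx_att_layer(x,
                                             config,
                                             is_training=is_training,
                                             attn_impl=attn_impl,
                                             layer_idx=layer_idx)
    return x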
Example #3
def axial_rowmajor(x, config, is_training=True, causal=True):
    """Full attention matrix with sqrt decomposition."""
    bsize = x.shape[0]
    seq_len = x.shape.as_list()[1]
    head_dim = config.model_size // config.num_heads
    assert seq_len % config.max_seg_len == 0
    num_seg = seq_len // config.max_seg_len
    # Fold the sequence into num_seg segments of max_seg_len tokens each.
    x_sqr = tf.reshape(x,
                       [bsize, num_seg, config.max_seg_len, config.model_size])
    q_row_local, key_row_local, value_row_local = attention.get_qkv(
        x_sqr,
        x_sqr,
        x_sqr,
        hidden_size=config.model_size,
        num_heads=config.num_heads,
        bias=config.dense_use_bias)
    # Row (local) attention: logits of every query against the keys of its own
    # segment, shape (bsize, num_seg, seg_len, num_heads, seg_len).
    local_logits = tf.einsum('bsqhd,bskhd->bsqhk', q_row_local, key_row_local)
    row_probs = attention.float32_softmax(local_logits, axis=-1)
    if is_training:
        row_probs = tf.nn.dropout(row_probs, rate=config.dropatt)

    # Per-position summary of the segment, reused below as the value for the
    # cross-segment (column) attention.
    row_attn_out = tf.einsum('bsqhk,bskhd->bsqhd', row_probs, value_row_local)
    # Build the summary keys that other segments will attend to.
    if config.row_summary == 'none':
        key_row = key_row_local
    elif config.row_summary in ['wsum', 'proj', 'wsum_proj']:
        if 'wsum' in config.row_summary:
            # Attention-weighted sum of the segment's keys.
            pre_summary = tf.einsum('bsqhk,bskhd->bsqhd', row_probs,
                                    key_row_local)
        else:
            pre_summary = row_attn_out
        if 'proj' in config.row_summary:
            # Optional projection with residual/norm on top of the summary.
            with tf.variable_scope('rowmajor_param_post'):
                key_row = ops.trail_dense(pre_summary,
                                          config.model_size,
                                          begin_axis=-2,
                                          bias=config.dense_use_bias)
                key_row = ops.postprocess(x_sqr, key_row, config, is_training)
                _, key_row = ops.preprocess(key_row, config)
                key_row = ops.trail_dense(key_row,
                                          [config.num_heads, head_dim],
                                          bias=config.dense_use_bias)
        else:
            key_row = pre_summary
    else:
        raise ValueError('Unknown row summary %s' % config.row_summary)
    if causal:
        # Within a segment, each position may attend to itself and earlier
        # offsets (non-strict causal mask).
        local_mask = get_causal_mask(q_row_local, axis=2, is_strict=False)
        local_logits += local_mask[:, tf.newaxis, :]

    # Column (global) attention: each position scores one summary key per
    # segment, taken at the same within-segment offset.
    global_logits = tf.einsum('bqlhd,bklhd->bqlhk', q_row_local, key_row)
    if causal:
        # Across segments, only strictly earlier segments are visible, since a
        # segment's own summary mixes in future positions.
        global_mask = get_causal_mask(q_row_local, axis=1, is_strict=True)
        global_logits += global_mask[:, tf.newaxis, tf.newaxis, :]
    # (bsize, num_seg, seg_len, n_head, seg_len + num_seg)
    joint_logits = tf.concat([local_logits, global_logits], axis=-1)
    # One softmax over the max_seg_len local keys and num_seg summary keys.
    attn_probs = attention.float32_softmax(joint_logits, axis=-1)
    local_att, global_att = tf.split(attn_probs, [config.max_seg_len, num_seg],
                                     axis=-1)
    if is_training:
        local_att = tf.nn.dropout(local_att, rate=config.dropatt)
    local_merged = tf.einsum('bsqhk,bskhd->bsqhd', local_att, value_row_local)
    global_merged = tf.einsum('bqlhv,bvlhd->bqlhd', global_att, row_attn_out)
    # Merge both branches and flatten segments back to (bsize, seq_len, ...).
    joint_merged = tf.reshape(local_merged + global_merged,
                              [bsize, seq_len, config.num_heads, head_dim])
    output = ops.trail_dense(joint_merged,
                             config.model_size,
                             begin_axis=-2,
                             bias=config.dense_use_bias)
    return output
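The cost saving comes from the shapes: with seq_len = num_seg * max_seg_len, each query scores only the max_seg_len keys of its own segment plus num_seg summary keys (one per segment, taken at the same within-segment offset), so the joint logits have on the order of L*sqrt(L) entries instead of L^2 when max_seg_len is about sqrt(L). A minimal shape-check sketch with made-up sizes (not part of the original code):

import tensorflow.compat.v1 as tf

# Illustrative sizes only: seq_len = num_seg * seg_len = 128.
bsize, num_seg, seg_len, num_heads, head_dim = 2, 8, 16, 4, 32
q = tf.random.normal([bsize, num_seg, seg_len, num_heads, head_dim])
k = tf.random.normal([bsize, num_seg, seg_len, num_heads, head_dim])
summary_k = tf.random.normal([bsize, num_seg, seg_len, num_heads, head_dim])

# Row (local) logits: every query against the keys of its own segment.
local_logits = tf.einsum('bsqhd,bskhd->bsqhk', q, k)
# Column (global) logits: every query against one summary key per segment.
global_logits = tf.einsum('bqlhd,bklhd->bqlhk', q, summary_k)

print(local_logits.shape)   # (2, 8, 16, 4, 16) -> max_seg_len keys per query
print(global_logits.shape)  # (2, 8, 16, 4, 8)  -> num_seg keys per query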