def vanilla_transformer_layer(x, config, is_training=True, attn_bias=None,
                              layer_idx=0):  # pylint: disable=unused-argument
  """Transformer layer: attention + ffn."""
  # Attention
  with tf.variable_scope('attn'):
    shortcut, x = ops.preprocess(x, config)
    x = attention.multihead_attention(
        x, x, x, config.model_size, config.num_heads,
        is_training=is_training, dropatt=config.dropatt,
        attn_bias=attn_bias, bias=config.dense_use_bias)
    x = ops.postprocess(shortcut, x, config, is_training)

  # FFN
  with tf.variable_scope('ffn'):
    shortcut, x = ops.preprocess(x, config)
    x = ops.ffn(x, is_training, config.dropout)
    x = ops.postprocess(shortcut, x, config, is_training)

  return x
def transformer_approx_att_layer(x, config, is_training=True, attn_bias=None,
                                 attn_impl=None, layer_idx=0):
  """Transformer layer: approximated attention + ffn."""
  # Attention
  with tf.variable_scope('attn'):
    shortcut, x = ops.preprocess(x, config)
    x = attn_impl(x, config, is_training=is_training)
    x = ops.postprocess(shortcut, x, config, is_training)

  # FFN
  with tf.variable_scope('ffn'):
    shortcut, x = ops.preprocess(x, config)
    x = ops.ffn(x, is_training, config.dropout)
    x = ops.postprocess(shortcut, x, config, is_training)

  return x
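# Illustrative sketch, not part of the original module: one way the two layer
# wrappers above might be stacked into a full encoder. `transformer_stack` and
# the config fields `num_layers` and `attn_type` are assumed names introduced
# only for this example; `axial_rowmajor` (defined below) is passed in as the
# approximated-attention implementation.
def transformer_stack(x, config, is_training=True, attn_bias=None):
  """Hypothetical driver applying `config.num_layers` transformer layers."""
  for layer_idx in range(config.num_layers):
    with tf.variable_scope('layer_%d' % layer_idx):
      if getattr(config, 'attn_type', 'full') == 'full':
        x = vanilla_transformer_layer(
            x, config, is_training=is_training, attn_bias=attn_bias,
            layer_idx=layer_idx)
      else:
        x = transformer_approx_att_layer(
            x, config, is_training=is_training, attn_bias=attn_bias,
            attn_impl=axial_rowmajor, layer_idx=layer_idx)
  return x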
def axial_rowmajor(x, config, is_training=True, causal=True):
  """Full attention matrix with sqrt decomposition."""
  bsize = x.shape[0]
  seq_len = x.shape.as_list()[1]
  head_dim = config.model_size // config.num_heads
  assert seq_len % config.max_seg_len == 0
  num_seg = seq_len // config.max_seg_len
  x_sqr = tf.reshape(
      x, [bsize, num_seg, config.max_seg_len, config.model_size])
  q_row_local, key_row_local, value_row_local = attention.get_qkv(
      x_sqr, x_sqr, x_sqr, hidden_size=config.model_size,
      num_heads=config.num_heads, bias=config.dense_use_bias)
  local_logits = tf.einsum('bsqhd,bskhd->bsqhk', q_row_local, key_row_local)
  row_probs = attention.float32_softmax(local_logits, axis=-1)
  if is_training:
    row_probs = tf.nn.dropout(row_probs, rate=config.dropatt)

  row_attn_out = tf.einsum('bsqhk,bskhd->bsqhd', row_probs, value_row_local)
  if config.row_summary == 'none':
    key_row = key_row_local
  elif config.row_summary in ['wsum', 'proj', 'wsum_proj']:
    if 'wsum' in config.row_summary:
      pre_summary = tf.einsum('bsqhk,bskhd->bsqhd', row_probs, key_row_local)
    else:
      pre_summary = row_attn_out
    if 'proj' in config.row_summary:
      with tf.variable_scope('rowmajor_param_post'):
        key_row = ops.trail_dense(pre_summary, config.model_size,
                                  begin_axis=-2, bias=config.dense_use_bias)
        key_row = ops.postprocess(x_sqr, key_row, config, is_training)
        _, key_row = ops.preprocess(key_row, config)
        key_row = ops.trail_dense(key_row, [config.num_heads, head_dim],
                                  bias=config.dense_use_bias)
    else:
      key_row = pre_summary
  else:
    raise ValueError('Unknown row summary %s' % config.row_summary)

  if causal:
    local_mask = get_causal_mask(q_row_local, axis=2, is_strict=False)
    local_logits += local_mask[:, tf.newaxis, :]

  global_logits = tf.einsum('bqlhd,bklhd->bqlhk', q_row_local, key_row)
  if causal:
    global_mask = get_causal_mask(q_row_local, axis=1, is_strict=True)
    global_logits += global_mask[:, tf.newaxis, tf.newaxis, :]

  # (bsize, num_seg, seg_len, n_head, seg_len + num_seg)
  joint_logits = tf.concat([local_logits, global_logits], axis=-1)
  attn_probs = attention.float32_softmax(joint_logits, axis=-1)
  local_att, global_att = tf.split(attn_probs, [config.max_seg_len, num_seg],
                                   axis=-1)
  if is_training:
    local_att = tf.nn.dropout(local_att, rate=config.dropatt)
  local_merged = tf.einsum('bsqhk,bskhd->bsqhd', local_att, value_row_local)
  global_merged = tf.einsum('bqlhv,bvlhd->bqlhd', global_att, row_attn_out)
  joint_merged = tf.reshape(local_merged + global_merged,
                            [bsize, seq_len, config.num_heads, head_dim])
  output = ops.trail_dense(joint_merged, config.model_size, begin_axis=-2,
                           bias=config.dense_use_bias)
  return output
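# Illustrative sketch, not part of the original module: expected input/output
# shapes for `axial_rowmajor`. The helper name, batch size, and segment count
# below are arbitrary choices for the example; `config` is assumed to carry
# the same fields used above (model_size, num_heads, max_seg_len, dropatt,
# dense_use_bias, row_summary).
def _axial_rowmajor_shape_example(config):
  """Hypothetical helper: applies axial_rowmajor to a dummy input."""
  # seq_len must be a multiple of config.max_seg_len (asserted above); here we
  # use 8 segments, so the sequence is viewed as an (8, max_seg_len) grid.
  x = tf.zeros([2, 8 * config.max_seg_len, config.model_size])
  with tf.variable_scope('axial_rowmajor_example'):
    y = axial_rowmajor(x, config, is_training=False, causal=True)
  # Output keeps the input shape: (bsize, seq_len, model_size).
  return y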