def rel_attn_core(q_head, k_head_h, v_head_h, k_head_r, seg_embed, seg_mat,
                  r_w_bias, r_r_bias, r_s_bias, attn_mask, dropatt,
                  is_training, scale):
    '''Core relative positional attention operations.'''

    # content-based attention score
    ac = tf.einsum('ibnd,jbnd->ijbn', q_head + r_w_bias, k_head_h)

    # position-based attention score
    bd = tf.einsum('ibnd,jbnd->ijbn', q_head + r_r_bias, k_head_r)
    bd = rel_shift(bd, klen=tf.shape(ac)[1])

    # segment-based attention score
    if seg_mat is None:
        ef = 0
    else:
        ef = tf.einsum('ibnd,snd->ibns', q_head + r_s_bias, seg_embed)
        ef = tf.einsum('ijbs,ibns->ijbn', seg_mat, ef)

    # merge attention scores and perform masking
    attn_score = (ac + bd + ef) * scale
    if attn_mask is not None:
        # attn_score = attn_score * (1 - attn_mask) - 1e30 * attn_mask
        attn_score = attn_score - 1e30 * attn_mask

    # attention probability
    attn_prob = tf.nn.softmax(attn_score, 1)
    attn_prob = tf.layers.dropout(attn_prob, dropatt, training=is_training)

    # attention output
    attn_vec = tf.einsum('ijbn,jbnd->ibnd', attn_prob, v_head_h)

    return attn_vec
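# `rel_shift` is called above but not defined in this section. A minimal
# sketch, following the XLNet reference implementation (assumption: the same
# reshape-and-slice trick is intended here):
def rel_shift(x, klen=-1):
    '''Performs relative shift to form the relative attention score.'''
    x_size = tf.shape(x)

    # Swapping the first two axes and dropping the first row realigns each
    # query position with its relative positions, then `klen` columns are kept.
    x = tf.reshape(x, [x_size[1], x_size[0], x_size[2], x_size[3]])
    x = tf.slice(x, [1, 0, 0, 0], [-1, -1, -1, -1])
    x = tf.reshape(x, [x_size[0], x_size[1] - 1, x_size[2], x_size[3]])
    x = tf.slice(x, [0, 0, 0, 0], [-1, klen, -1, -1])

    return x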
def embedding_lookup(x, n_token, d_embed, initializer, use_tpu=True,
                     scope='embedding', tilda_embeddings=None, reuse=None,
                     dtype=tf.float32):
    '''TPU and GPU embedding_lookup function.'''
    if tilda_embeddings is not None:
        lookup_table = tilda_embeddings
    else:
        with tf.variable_scope(scope, reuse=reuse):
            lookup_table = tf.get_variable(
                'lookup_table', [n_token, d_embed],
                dtype=dtype, initializer=initializer)

    if use_tpu:
        # gather via one-hot matmul; handles both [len] and [len, bsz] ids
        one_hot_idx = tf.one_hot(x, n_token, dtype=dtype)
        if one_hot_idx.shape.ndims == 2:
            return (tf.einsum('in,nd->id', one_hot_idx, lookup_table),
                    lookup_table)
        else:
            return (tf.einsum('ibn,nd->ibd', one_hot_idx, lookup_table),
                    lookup_table)
    else:
        return tf.nn.embedding_lookup(lookup_table, x), lookup_table
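# Design note: on TPU, a dense one-hot matmul tends to outperform the sparse
# gather performed by `tf.nn.embedding_lookup`, which is why the lookup is
# expressed as an einsum when `use_tpu` is set. Both branches also return the
# table itself so the caller can tie it to the output softmax (see `lm_loss`).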
def noncausal_denominator(qs, ks):
    '''Computes FAVOR normalizer in noncausal attention.

    Args:
      qs: query_prime tensor of the shape [L,B,H,M].
      ks: key_prime tensor of the shape [L,B,H,M].

    Returns:
      FAVOR normalizer in noncausal attention.
    '''
    all_ones = tf.ones([ks.shape[0]])
    ks_sum = tf.einsum('lbhm,l->bhm', ks, all_ones)
    return tf.einsum('lbhm,bhm->lbh', qs, ks_sum)
def noncausal_numerator(qs, ks, vs):
    '''Computes not-normalized FAVOR noncausal attention AV.

    Args:
      qs: query_prime tensor of the shape [L,B,H,M].
      ks: key_prime tensor of the shape [L,B,H,M].
      vs: value tensor of the shape [L,B,H,D].

    Returns:
      Not-normalized FAVOR noncausal attention AV.
    '''
    kvs = tf.einsum('lbhm,lbhd->bhmd', ks, vs)
    return tf.einsum('lbhm,bhmd->lbhd', qs, kvs)
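# The two helpers above compose into full bidirectional FAVOR attention. A
# minimal sketch, assuming qs/ks/vs are already kernel-transformed and laid
# out as [L,B,H,*] (the helper name is illustrative):
def favor_noncausal_attention(qs, ks, vs):
    '''Normalized noncausal FAVOR attention (illustrative sketch).'''
    av = noncausal_numerator(qs, ks, vs)        # [L,B,H,D]
    normalizer = noncausal_denominator(qs, ks)  # [L,B,H]
    # broadcast the normalizer over the feature dimension D
    return av / tf.expand_dims(normalizer, -1)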
# Returning (result, grad) pairs with the tf.custom_gradient decorator: the
# backward pass rebuilds the prefix sums in reverse instead of storing every
# intermediate sum of the forward loop.
@tf.custom_gradient
def causal_numerator(qs, ks, vs):
    '''Computes not-normalized FAVOR causal attention A_{masked}V.

    Args:
      qs: query_prime tensor of the shape [L,B,H,M].
      ks: key_prime tensor of the shape [L,B,H,M].
      vs: value tensor of the shape [L,B,H,D].

    Returns:
      Not-normalized FAVOR causal attention A_{masked}V.
    '''
    result = []
    sums = tf.zeros_like(tf.einsum('ijk,ijl->ijkl', ks[0], vs[0]))

    for index in range(qs.shape[0]):
        sums = sums + tf.einsum('ijk,ijl->ijkl', ks[index], vs[index])
        result.append(
            tf.einsum('ijkl,ijk->ijl', sums, qs[index])[None, Ellipsis])
    result = tf.concat(result, axis=0)

    def grad(res_grad):
        grads = tf.zeros_like(tf.einsum('ijk,ijl->ijkl', ks[0], vs[0]))
        gr_sums = sums

        q_grads = []
        k_grads = []
        v_grads = []

        for index in range(qs.shape[0] - 1, -1, -1):
            q_grads.append(
                tf.einsum('ijkl,ijl->ijk', gr_sums,
                          res_grad[index])[None, Ellipsis])
            grads = grads + tf.einsum('ijk,ijl->ijkl', qs[index],
                                      res_grad[index])
            k_grads.append(
                tf.einsum('ijkl,ijl->ijk', grads, vs[index])[None, Ellipsis])
            v_grads.append(
                tf.einsum('ijkl,ijk->ijl', grads, ks[index])[None, Ellipsis])
            gr_sums = gr_sums - tf.einsum('ijk,ijl->ijkl', ks[index],
                                          vs[index])

        q_grads = tf.concat(q_grads[::-1], axis=0)
        k_grads = tf.concat(k_grads[::-1], axis=0)
        v_grads = tf.concat(v_grads[::-1], axis=0)

        return q_grads, k_grads, v_grads

    return result, grad
def lm_loss(hidden, target, n_token, d_model, initializer, lookup_table=None,
            tie_weight=False, bi_data=True, use_tpu=False):
    '''Computes the language modeling loss and predictions.'''
    with tf.variable_scope('lm_loss'):
        if tie_weight:
            assert lookup_table is not None, \
                'lookup_table cannot be None for tie_weight'
            softmax_w = lookup_table
        else:
            softmax_w = tf.get_variable(
                'weight', [n_token, d_model],
                dtype=hidden.dtype, initializer=initializer)

        softmax_b = tf.get_variable(
            'bias', [n_token], dtype=hidden.dtype,
            initializer=tf.zeros_initializer())

        logits = tf.einsum('ibd,nd->ibn', hidden, softmax_w) + softmax_b
        preds = tf.argmax(logits, axis=-1)

        if use_tpu:
            one_hot_target = tf.one_hot(target, n_token, dtype=logits.dtype)
            loss = -tf.reduce_sum(
                tf.nn.log_softmax(logits) * one_hot_target, -1)
        else:
            loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=target, logits=logits)

        return loss, preds
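# With `tie_weight=True`, the softmax projection reuses the embedding table
# returned by `embedding_lookup`. A minimal usage sketch (the intermediate
# names are illustrative assumptions):
#
#   word_emb, lookup_table = embedding_lookup(
#       input_ids, n_token, d_embed, initializer, use_tpu=False)
#   ...  # run the model over word_emb to obtain `hidden` of shape [L, B, D]
#   loss, preds = lm_loss(
#       hidden, target, n_token, d_model=d_embed, initializer=initializer,
#       lookup_table=lookup_table, tie_weight=True)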
def relu_kernel_transformation(data, is_query, projection_matrix=None,
                               numerical_stabilizer=0.001):
    '''Computes features for the ReLU-kernel.

    Computes random features for the ReLU kernel from
    https://arxiv.org/pdf/2009.14794.pdf.

    Args:
      data: input data tensor of the shape [B, L, H, D], where: B - batch
        dimension, L - attention dimensions, H - heads, D - features.
      is_query: indicates whether input data is a query or key tensor.
      projection_matrix: random Gaussian matrix of shape [M, D], where M
        stands for the number of random features and each D x D sub-block
        has pairwise orthogonal rows.
      numerical_stabilizer: small positive constant for numerical stability.

    Returns:
      Corresponding kernel feature map.
    '''
    del is_query
    if projection_matrix is None:
        return tf.nn.relu(data) + numerical_stabilizer
    else:
        ratio = tf.math.rsqrt(tf.cast(projection_matrix.shape[0], tf.float32))
        data_dash = ratio * tf.einsum('blhd,md->blhm', data, projection_matrix)
        return tf.nn.relu(data_dash) + numerical_stabilizer
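# A minimal usage sketch (assumption: a plain Gaussian projection is drawn
# here instead of the orthogonal block construction described in the paper):
#
#   m, d = 256, 64
#   projection = tf.random.normal([m, d])
#   q_prime = relu_kernel_transformation(q, True, projection)   # [B,L,H,M]
#   k_prime = relu_kernel_transformation(k, False, projection)  # [B,L,H,M]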
def call(self, inputs):
    '''Applies the einsum-based dense layer to `inputs`.'''
    ret = tf.einsum(self._einsum_string, inputs, self._kernel)
    if self._use_bias:
        ret += self._bias
    if self._activation is not None:
        ret = self._activation(ret)
    return ret
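# `call` above is a method of an einsum-based dense layer whose class body is
# not shown in this section. A minimal sketch of a compatible layer; the
# attribute names follow their uses in `call`, everything else (class name,
# constructor arguments) is an assumption:
class DenseEinsum(tf.keras.layers.Layer):
    '''Dense layer expressed as a single einsum (illustrative sketch).'''

    def __init__(self, einsum_string, kernel_shape, bias_shape=None,
                 activation=None, **kwargs):
        super(DenseEinsum, self).__init__(**kwargs)
        self._einsum_string = einsum_string
        self._kernel_shape = kernel_shape
        self._bias_shape = bias_shape
        self._use_bias = bias_shape is not None
        self._activation = activation

    def build(self, input_shape):
        self._kernel = self.add_weight(
            'kernel', shape=self._kernel_shape, initializer='glorot_uniform')
        if self._use_bias:
            self._bias = self.add_weight(
                'bias', shape=self._bias_shape, initializer='zeros')
        super(DenseEinsum, self).build(input_shape)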
def abs_attn_core(q_head, k_head, v_head, attn_mask, dropatt, is_training,
                  scale):
    '''Core absolute positional attention operations.'''
    attn_score = tf.einsum('ibnd,jbnd->ijbn', q_head, k_head)
    attn_score *= scale
    if attn_mask is not None:
        attn_score = attn_score - 1e30 * attn_mask

    # attention probability
    attn_prob = tf.nn.softmax(attn_score, 1)
    attn_prob = tf.layers.dropout(attn_prob, dropatt, training=is_training)

    # attention output
    attn_vec = tf.einsum('ijbn,jbnd->ibnd', attn_prob, v_head)

    return attn_vec
def head_projection(h, d_model, n_head, d_head, kernel_initializer, name):
    '''Project hidden states to a specific head with a 4D-shape.'''
    proj_weight = tf.get_variable(
        '{}/kernel'.format(name), [d_model, n_head, d_head],
        dtype=h.dtype, initializer=kernel_initializer)
    head = tf.einsum('ibh,hnd->ibnd', h, proj_weight)
    return head
def positional_embedding(pos_seq, inv_freq, bsz=None):
    '''Builds sinusoidal positional embeddings of shape [len, bsz, d_model].'''
    sinusoid_inp = tf.einsum('i,d->id', pos_seq, inv_freq)
    pos_emb = tf.concat([tf.sin(sinusoid_inp), tf.cos(sinusoid_inp)], -1)
    pos_emb = pos_emb[:, None, :]

    if bsz is not None:
        pos_emb = tf.tile(pos_emb, [1, bsz, 1])

    return pos_emb
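# A minimal sketch of the expected inputs, following the Transformer-XL /
# XLNet convention for `inv_freq` (qlen/klen/bsz are illustrative):
#
#   freq_seq = tf.range(0, d_model, 2.0)
#   inv_freq = 1 / (10000 ** (freq_seq / d_model))
#   pos_seq = tf.range(klen, -qlen, -1.0)  # relative positions
#   pos_emb = positional_embedding(pos_seq, inv_freq, bsz)  # [len, bsz, d]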
# Inner gradient function for the causal FAVOR denominator; it closes over
# `qs`, `ks` and the final prefix sum `sums` of the forward pass.
def grad(res_grad):
    k_grad = tf.zeros_like(ks[0])
    gr_sums = sums

    q_grads = []
    k_grads = []

    for index in range(qs.shape[0] - 1, -1, -1):
        q_grads.append(
            tf.einsum('ijk,ij->ijk', gr_sums, res_grad[index])[None, Ellipsis])
        k_grad = k_grad + tf.einsum('ijk,ij->ijk', qs[index], res_grad[index])
        k_grads.append(k_grad[None, Ellipsis])
        gr_sums = gr_sums - ks[index]

    q_grads = tf.concat(q_grads[::-1], axis=0)
    k_grads = tf.concat(k_grads[::-1], axis=0)

    return q_grads, k_grads
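# The gradient above is the backward pass of the causal FAVOR normalizer; its
# enclosing forward pass is not shown in this section. A minimal sketch,
# mirroring `causal_numerator` (reconstructed from the FAVOR+ formulation,
# not taken from this section):
#
#   @tf.custom_gradient
#   def causal_denominator(qs, ks):
#       result = []
#       sums = tf.zeros_like(ks[0])
#       for index in range(qs.shape[0]):
#           sums = sums + ks[index]  # running prefix sum over keys
#           result.append(
#               tf.reduce_sum(qs[index] * sums, axis=2)[None, Ellipsis])
#       result = tf.concat(result, axis=0)  # [L,B,H]
#       return result, grad  # `grad` as defined above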
def dense_layer_3d(input_tensor,
                   num_attention_heads,
                   head_size,
                   initializer,
                   activation,
                   use_einsum,
                   name=None,
                   trainable=True):
    """A dense layer with 3D kernel.

    Args:
      input_tensor: float Tensor of shape [batch, seq_length, hidden_size].
      num_attention_heads: Number of attention heads.
      head_size: The size per attention head.
      initializer: Kernel initializer.
      activation: Activation function.
      use_einsum: bool. Whether to use einsum or reshape+matmul for dense
        layers.
      name: The name scope of this layer.
      trainable: bool. Whether the layer variables are trainable.

    Returns:
      float logits Tensor.
    """
    input_shape = util.get_shape_list(input_tensor)
    hidden_size = input_shape[2]

    with tf.variable_scope(name):
        w = tf.get_variable(
            name="kernel",
            shape=[hidden_size, num_attention_heads * head_size],
            initializer=initializer,
            trainable=trainable)
        w = tf.reshape(w, [hidden_size, num_attention_heads, head_size])
        b = tf.get_variable(
            name="bias",
            shape=[num_attention_heads * head_size],
            initializer=tf.zeros_initializer,
            trainable=trainable)
        b = tf.reshape(b, [num_attention_heads, head_size])

        if use_einsum:
            ret = tf.einsum("BFH,HND->BFND", input_tensor, w)
        else:
            ret = einsum_via_matmul(input_tensor, w, 1)
        ret += b

        if activation is not None:
            return activation(ret)
        else:
            return ret
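# `einsum_via_matmul` is called above but not defined in this section. A
# minimal sketch consistent with its uses here: a reshape+matmul equivalent of
# the einsum path. Assumes `np` is NumPy (import numpy as np) and that the
# contracted and output dimensions are statically known:
def einsum_via_matmul(input_tensor, w, num_inner_dims):
    '''Implements the einsum contractions above via matmul and reshapes.'''
    input_shape = util.get_shape_list(input_tensor)
    w_shape = util.get_shape_list(w)

    batch_dims = input_shape[:-num_inner_dims]
    inner_dims = input_shape[-num_inner_dims:]
    outer_dims = w_shape[num_inner_dims:]
    inner_dim = np.prod(inner_dims)
    outer_dim = np.prod(outer_dims)

    # flatten the contracted dims on both operands, matmul, restore the shape
    if num_inner_dims > 1:
        input_tensor = tf.reshape(input_tensor, batch_dims + [inner_dim])
    if len(w_shape) > 2:
        w = tf.reshape(w, [inner_dim, outer_dim])
    ret = tf.matmul(input_tensor, w)
    if len(outer_dims) > 1:
        ret = tf.reshape(ret, batch_dims + outer_dims)
    return ret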
def post_attention(h, attn_vec, d_model, n_head, d_head, dropout, is_training,
                   kernel_initializer, residual=True):
    '''Post-attention processing.'''
    # post-attention projection (back to `d_model`)
    proj_o = tf.get_variable(
        'o/kernel', [d_model, n_head, d_head],
        dtype=h.dtype, initializer=kernel_initializer)
    attn_out = tf.einsum('ibnd,hnd->ibh', attn_vec, proj_o)

    attn_out = tf.layers.dropout(attn_out, dropout, training=is_training)
    if residual:
        output = util.layer_norm(attn_out + h, name='LayerNorm')
    else:
        output = util.layer_norm(attn_out, name='LayerNorm')

    return output
def dense_layer_2d(input_tensor,
                   output_size,
                   initializer,
                   activation,
                   use_einsum,
                   num_attention_heads=1,
                   name=None,
                   trainable=True):
    """A dense layer with 2D kernel.

    Args:
      input_tensor: Float tensor with rank 3.
      output_size: The size of output dimension.
      initializer: Kernel initializer.
      activation: Activation function.
      use_einsum: bool. Whether to use einsum or reshape+matmul for dense
        layers.
      num_attention_heads: number of attention heads in attention layer.
      name: The name scope of this layer.
      trainable: bool. Whether the layer variables are trainable.

    Returns:
      float logits Tensor.
    """
    del num_attention_heads  # unused
    input_shape = util.get_shape_list(input_tensor)
    hidden_size = input_shape[2]

    with tf.variable_scope(name):
        w = tf.get_variable(
            name="kernel",
            shape=[hidden_size, output_size],
            initializer=initializer,
            trainable=trainable)
        b = tf.get_variable(
            name="bias",
            shape=[output_size],
            initializer=tf.zeros_initializer,
            trainable=trainable)

        if use_einsum:
            ret = tf.einsum("BFH,HO->BFO", input_tensor, w)
        else:
            ret = tf.matmul(input_tensor, w)
        ret += b

        if activation is not None:
            return activation(ret)
        else:
            return ret
def softmax_kernel_transformation(data, is_query, projection_matrix=None,
                                  numerical_stabilizer=0.000001):
    '''Computes random features for the softmax kernel using FAVOR+ mechanism.

    Computes random features for the softmax kernel using FAVOR+ mechanism
    from https://arxiv.org/pdf/2009.14794.pdf.

    Args:
      data: input data tensor of the shape [B, L, H, D], where: B - batch
        dimension, L - attention dimensions, H - heads, D - features.
      is_query: indicates whether input data is a query or key tensor.
      projection_matrix: random Gaussian matrix of shape [M, D], where M
        stands for the number of random features and each D x D sub-block
        has pairwise orthogonal rows.
      numerical_stabilizer: small positive constant for numerical stability.

    Returns:
      Corresponding kernel feature map.
    '''
    # equals data.shape[-1] ** -0.25, the h(x) normalizer from the paper
    data_normalizer = \
        tf.math.rsqrt(1 / tf.math.rsqrt(tf.cast(data.shape[-1], tf.float32)))
    ratio = tf.math.rsqrt(
        tf.cast(
            projection_matrix.shape[0]
            if projection_matrix is not None else 1.0, tf.float32))

    # project the normalized inputs onto the random features
    data_dash = tf.einsum(
        'blhd,md->blhm', data_normalizer * data, projection_matrix)

    diag_data = tf.math.square(data)
    diag_data = tf.math.reduce_sum(
        diag_data, axis=tf.keras.backend.ndim(data) - 1)
    diag_data = (diag_data / 2.0) * data_normalizer * data_normalizer
    diag_data = tf.expand_dims(diag_data, axis=tf.keras.backend.ndim(data) - 1)

    if is_query:
        # stabilize per query position: subtract the max over the features
        last_dims_t = (len(data_dash.shape) - 1,)
        data_dash = ratio * (
            tf.math.exp(data_dash - diag_data - tf.math.reduce_max(
                data_dash, axis=last_dims_t, keepdims=True)) +
            numerical_stabilizer)
    else:
        # stabilize keys with a single global max
        data_dash = ratio * (
            tf.math.exp(data_dash - diag_data -
                        tf.math.reduce_max(data_dash)) +
            numerical_stabilizer)

    return data_dash
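# End-to-end FAVOR+ sketch tying the pieces together. Note that the kernel
# transformations produce [B,L,H,M] while the numerator/denominator helpers
# above expect [L,B,H,*], hence the transposes (names are illustrative, and
# `causal_denominator` is the sketch given earlier):
#
#   q_prime = softmax_kernel_transformation(q, True, projection)   # [B,L,H,M]
#   k_prime = softmax_kernel_transformation(k, False, projection)  # [B,L,H,M]
#   qs = tf.transpose(q_prime, [1, 0, 2, 3])                       # [L,B,H,M]
#   ks = tf.transpose(k_prime, [1, 0, 2, 3])                       # [L,B,H,M]
#   vs = tf.transpose(v, [1, 0, 2, 3])                             # [L,B,H,D]
#   av = causal_numerator(qs, ks, vs)      # or noncausal_numerator
#   norm = causal_denominator(qs, ks)      # or noncausal_denominator
#   out = tf.transpose(av / tf.expand_dims(norm, -1), [1, 0, 2, 3])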
def dense_layer_3d_proj(input_tensor,
                        hidden_size,
                        head_size,
                        initializer,
                        activation,
                        use_einsum,
                        name=None):
    """A dense layer with 3D kernel for projection.

    Args:
      input_tensor: float Tensor of shape [batch, from_seq_length,
        num_attention_heads, size_per_head].
      hidden_size: The size of hidden layer.
      head_size: The size of head.
      initializer: Kernel initializer.
      activation: Activation function.
      use_einsum: bool. Whether to use einsum or reshape+matmul for dense
        layers.
      name: The name scope of this layer.

    Returns:
      float logits Tensor.
    """
    input_shape = util.get_shape_list(input_tensor)
    num_attention_heads = input_shape[2]

    with tf.variable_scope(name):
        w = tf.get_variable(
            name="kernel",
            shape=[num_attention_heads * head_size, hidden_size],
            initializer=initializer)
        w = tf.reshape(w, [num_attention_heads, head_size, hidden_size])
        b = tf.get_variable(
            name="bias",
            shape=[hidden_size],
            initializer=tf.zeros_initializer)

        if use_einsum:
            ret = tf.einsum("BFND,NDH->BFH", input_tensor, w)
        else:
            ret = einsum_via_matmul(input_tensor, w, 2)
        ret += b

        if activation is not None:
            return activation(ret)
        else:
            return ret
def two_stream_rel_attn(h, g, r, mems, r_w_bias, r_r_bias, seg_mat, r_s_bias,
                        seg_embed, attn_mask_h, attn_mask_g, target_mapping,
                        d_model, n_head, d_head, dropout, dropatt, is_training,
                        kernel_initializer, scope='rel_attn'):
    '''Two-stream attention with relative positional encoding.'''
    scale = 1 / (d_head ** 0.5)
    with tf.variable_scope(scope, reuse=False):

        # content based attention score
        if mems is not None and mems.shape.ndims > 1:
            cat = tf.concat([mems, h], 0)
        else:
            cat = h

        # content-based key head
        k_head_h = head_projection(
            cat, d_model, n_head, d_head, kernel_initializer, 'k')

        # content-based value head
        v_head_h = head_projection(
            cat, d_model, n_head, d_head, kernel_initializer, 'v')

        # position-based key head
        k_head_r = head_projection(
            r, d_model, n_head, d_head, kernel_initializer, 'r')

        ##### h-stream
        # content-stream query head
        q_head_h = head_projection(
            h, d_model, n_head, d_head, kernel_initializer, 'q')

        # core attention ops
        attn_vec_h = rel_attn_core(
            q_head_h, k_head_h, v_head_h, k_head_r, seg_embed, seg_mat,
            r_w_bias, r_r_bias, r_s_bias, attn_mask_h, dropatt, is_training,
            scale)

        # post processing
        output_h = post_attention(
            h, attn_vec_h, d_model, n_head, d_head, dropout, is_training,
            kernel_initializer)

    with tf.variable_scope(scope, reuse=True):
        ##### g-stream
        # query-stream query head
        q_head_g = head_projection(
            g, d_model, n_head, d_head, kernel_initializer, 'q')

        # core attention ops
        if target_mapping is not None:
            q_head_g = tf.einsum('mbnd,mlb->lbnd', q_head_g, target_mapping)
            attn_vec_g = rel_attn_core(
                q_head_g, k_head_h, v_head_h, k_head_r, seg_embed, seg_mat,
                r_w_bias, r_r_bias, r_s_bias, attn_mask_g, dropatt,
                is_training, scale)
            attn_vec_g = tf.einsum(
                'lbnd,mlb->mbnd', attn_vec_g, target_mapping)
        else:
            attn_vec_g = rel_attn_core(
                q_head_g, k_head_h, v_head_h, k_head_r, seg_embed, seg_mat,
                r_w_bias, r_r_bias, r_s_bias, attn_mask_g, dropatt,
                is_training, scale)

        # post processing
        output_g = post_attention(
            g, attn_vec_g, d_model, n_head, d_head, dropout, is_training,
            kernel_initializer)

    return output_h, output_g