def calc_minilm_loss(loss_fct, s, t, attn_mask, num_relation_heads=0):
    # Initialize head_num
    if num_relation_heads > 0 and num_relation_heads != s.shape[1]:
        # s' shape: [bs, seq_len, head_num, head_dim]
        s = tensor.transpose(x=s, perm=[0, 2, 1, 3])
        # s' shape: [bs, seq_len, num_relation_heads, head_dim_new]
        s = tensor.reshape(x=s, shape=[0, 0, num_relation_heads, -1])
        # s' shape: [bs, num_relation_heads, seq_len, head_dim_new]
        s = tensor.transpose(x=s, perm=[0, 2, 1, 3])
    if num_relation_heads > 0 and num_relation_heads != t.shape[1]:
        t = tensor.transpose(x=t, perm=[0, 2, 1, 3])
        t = tensor.reshape(x=t, shape=[0, 0, num_relation_heads, -1])
        t = tensor.transpose(x=t, perm=[0, 2, 1, 3])

    pad_seq_len = s.shape[2]
    s_head_dim, t_head_dim = s.shape[3], t.shape[3]
    scaled_dot_product_s = tensor.matmul(
        x=s, y=s, transpose_y=True) / math.sqrt(s_head_dim)
    del s
    scaled_dot_product_s += attn_mask

    scaled_dot_product_t = tensor.matmul(
        x=t, y=t, transpose_y=True) / math.sqrt(t_head_dim)
    del t
    scaled_dot_product_t += attn_mask
    loss = loss_fct(F.log_softmax(scaled_dot_product_s),
                    F.softmax(scaled_dot_product_t))
    return loss
def GetBaselineOut(self):
    paddle.disable_static(place=paddle.CUDAPlace(0))
    tensor_query = paddle.to_tensor(self.query, stop_gradient=False)
    if self.has_attn_mask:
        attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=False)
    else:
        attn_mask = None
    residual = tensor_query

    ln1_out = tensor_query
    if self.pre_layer_norm:
        ln1_out = self.norm1(tensor_query)

    q = self.q_proj(ln1_out)
    q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim])
    q_out = tensor.transpose(x=q, perm=[0, 2, 1, 3])
    k = self.k_proj(ln1_out)
    v = self.v_proj(ln1_out)
    k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim])
    k_out = tensor.transpose(x=k, perm=[0, 2, 1, 3])
    v = tensor.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim])
    v_out = tensor.transpose(x=v, perm=[0, 2, 1, 3])

    qk_out = layers.matmul(x=q_out,
                           y=k_out,
                           transpose_y=True,
                           alpha=self.head_dim**-0.5)

    if attn_mask is not None:
        attn_mask = _convert_attention_mask(attn_mask, qk_out.dtype)
        attn_mask_out = qk_out + attn_mask
        softmax_out = F.softmax(attn_mask_out)
    else:
        softmax_out = F.softmax(qk_out)

    if self.dropout_prob:
        dropout_out = F.dropout(softmax_out,
                                self.dropout_prob,
                                training=self.training,
                                mode="upscale_in_train")
        qktv_out = tensor.matmul(dropout_out, v_out)
    else:
        qktv_out = tensor.matmul(softmax_out, v_out)

    fmha_out = tensor.transpose(qktv_out, perm=[0, 2, 1, 3])
    out_linear_in = tensor.reshape(
        x=fmha_out, shape=[0, 0, fmha_out.shape[2] * fmha_out.shape[3]])
    out = self.out_proj(out_linear_in)

    residual_out = residual + self.dropout(out)
    if not self.pre_layer_norm:
        final_out = self.norm1(residual_out)
    else:
        final_out = residual_out
    paddle.autograd.backward([final_out], [paddle.to_tensor(self.dout)],
                             retain_graph=True)
    return final_out, tensor_query.grad
def compute_kv(self, key, value):
    r"""
    Applies linear projection on input keys and values, then splits heads
    (reshape and transpose) to get keys and values from different
    representation subspaces. The results are used as key-value pairs for
    subsequent multiple parallel attention.

    It is part of calculations in multi-head attention, and is provided as
    a method to pre-compute and prefetch these results, thus we can use them
    to construct cache for inference.
    """
    k = self.k_proj(key)
    if _global_parallel_strategy == "mp":
        auto.shard_tensor(self.k_proj.weight,
                          dist_attr={
                              "process_mesh": _global_process_mesh,
                              "dims_mapping": [-1, 0]
                          })
    elif _global_parallel_strategy == "dp_mp":
        auto.shard_tensor(self.k_proj.weight,
                          dist_attr={
                              "process_mesh": _global_process_mesh,
                              "dims_mapping": [-1, 1]
                          })
    v = self.v_proj(value)
    if _global_parallel_strategy == "mp":
        auto.shard_tensor(self.v_proj.weight,
                          dist_attr={
                              "process_mesh": _global_process_mesh,
                              "dims_mapping": [-1, 0]
                          })
    elif _global_parallel_strategy == "dp_mp":
        auto.shard_tensor(self.v_proj.weight,
                          dist_attr={
                              "process_mesh": _global_process_mesh,
                              "dims_mapping": [-1, 1]
                          })
    k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim])
    k = tensor.transpose(x=k, perm=[0, 2, 1, 3])
    v = tensor.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim])
    v = tensor.transpose(x=v, perm=[0, 2, 1, 3])
    return k, v
def _prepare_qkv(self, query, key, value, use_cache=False, cache=None):
    r"""
    Prepares linear projected queries, keys and values for usage of
    subsequent multiple parallel attention. If `cache` is not None, using
    cached results to reduce redundant calculations.
    """
    q = self.q_proj(query)
    q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim])
    q = tensor.transpose(x=q, perm=[0, 2, 1, 3])

    if isinstance(cache, self.StaticCache):
        # for encoder-decoder attention in inference and has cached
        k, v = cache.k, cache.v
    else:
        k, v = self.compute_kv(key, value)

    if isinstance(cache, self.Cache):
        # for decoder self-attention in inference
        k = tensor.concat([cache.k, k], axis=2)
        v = tensor.concat([cache.v, v], axis=2)
    if use_cache is True:
        cache = self.Cache(k, v)

    return (q, k, v) if use_cache is False else (q, k, v, cache)
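# Hedged sketch (editor illustration, not from the original file): the Cache and
# StaticCache containers referenced above are typically plain namedtuples, as in
# paddle.nn.MultiHeadAttention. StaticCache holds pre-computed encoder k/v for
# cross-attention; Cache holds the running k/v that _prepare_qkv concatenates
# along axis=2 (the sequence axis) during incremental decoding.
import collections

Cache = collections.namedtuple("Cache", ["k", "v"])
StaticCache = collections.namedtuple("StaticCache", ["k", "v"])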
def calc_minilm_loss(loss_fct, s, t, attn_mask, num_relation_heads=0):
    """
    Calculates the loss for the Q-Q, K-K and V-V relations from MiniLMv2.

    Args:
        loss_fct (callable):
            Loss function for distillation. Only kl_div loss is supported now.
        s (Tensor):
            Q, K or V of the student.
        t (Tensor):
            Q, K or V of the teacher.
        attn_mask (Tensor):
            Attention mask for the relation.
        num_relation_heads (int):
            The number of relation heads. 0 means `num_relation_heads` equals
            the original number of heads. Defaults to 0.

    Returns:
        Tensor: MiniLM loss value.
    """
    # Initialize head_num
    if num_relation_heads > 0 and num_relation_heads != s.shape[1]:
        # s' shape: [bs, seq_len, head_num, head_dim]
        s = tensor.transpose(x=s, perm=[0, 2, 1, 3])
        # s' shape: [bs, seq_len, num_relation_heads, head_dim_new]
        s = tensor.reshape(x=s, shape=[0, 0, num_relation_heads, -1])
        # s' shape: [bs, num_relation_heads, seq_len, head_dim_new]
        s = tensor.transpose(x=s, perm=[0, 2, 1, 3])
    if num_relation_heads > 0 and num_relation_heads != t.shape[1]:
        t = tensor.transpose(x=t, perm=[0, 2, 1, 3])
        t = tensor.reshape(x=t, shape=[0, 0, num_relation_heads, -1])
        t = tensor.transpose(x=t, perm=[0, 2, 1, 3])

    s_head_dim, t_head_dim = s.shape[3], t.shape[3]
    scaled_dot_product_s = tensor.matmul(
        x=s, y=s, transpose_y=True) / math.sqrt(s_head_dim)
    del s
    scaled_dot_product_s += attn_mask

    scaled_dot_product_t = tensor.matmul(
        x=t, y=t, transpose_y=True) / math.sqrt(t_head_dim)
    del t
    scaled_dot_product_t += attn_mask
    loss = loss_fct(F.log_softmax(scaled_dot_product_s),
                    F.softmax(scaled_dot_product_t))
    return loss
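# Hedged usage sketch for calc_minilm_loss (editor illustration, not taken from
# the original project): the kl_div wrapper, tensor shapes and the all-zero
# attention mask below are assumptions chosen only to show the call pattern.
import math
import paddle
import paddle.nn.functional as F
from paddle import tensor


def kl_div_loss(log_p_student, p_teacher):
    # Matches the (log_softmax, softmax) pair that calc_minilm_loss passes in.
    return F.kl_div(log_p_student, p_teacher, reduction="batchmean")


bs, num_heads, seq_len, head_dim = 2, 12, 8, 64
s = paddle.randn([bs, num_heads, seq_len, head_dim])  # student Q (or K, or V)
t = paddle.randn([bs, num_heads, seq_len, head_dim])  # teacher Q (or K, or V)
attn_mask = paddle.zeros([bs, 1, 1, seq_len])         # additive mask, 0 keeps every position
loss = calc_minilm_loss(kl_div_loss, s, t, attn_mask, num_relation_heads=0)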
def forward(self,
            query,
            key,
            value,
            attn_mask=None,
            use_cache=False,
            cache=None):
    r"""
    Applies multi-head attention to map queries and a set of key-value pairs
    to outputs.
    """
    key = query if key is None else key
    value = query if value is None else value
    # compute q, k, v
    if use_cache is False:
        if self.fuse:
            q, k, v = self._fuse_prepare_qkv(query)
        else:
            q, k, v = self._prepare_qkv(query, key, value, use_cache, cache)
    else:
        q, k, v, cache = self._prepare_qkv(query, key, value, use_cache, cache)

    # scale dot product attention
    product = layers.matmul(x=q,
                            y=k,
                            transpose_y=True,
                            alpha=self.head_dim**-0.5)

    # if attn_mask is not None:
    #     product = product + attn_mask
    # weights = F.softmax(product)
    weights = incubate.softmax_mask_fuse_upper_triangle(product)

    if self.dropout:
        with get_rng_state_tracker().rng_state('local_seed'):
            weights = F.dropout(weights,
                                self.dropout,
                                training=self.training,
                                mode="upscale_in_train")

    out = tensor.matmul(weights, v)

    # combine heads
    out = tensor.transpose(out, perm=[0, 2, 1, 3])
    out = tensor.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])
    # project to output
    out = self.out_proj(out)

    outs = [out]
    if self.need_weights:
        outs.append(weights)
    if use_cache:
        outs.append(cache)
    return out if len(outs) == 1 else tuple(outs)
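# Hedged note (editor illustration): softmax_mask_fuse_upper_triangle applies a
# causal (upper-triangular) mask and the softmax in one fused kernel, which is
# why the explicit `product + attn_mask` path above is commented out. A rough
# unfused equivalent, written here only for comparison, would look like this:
import paddle
import paddle.nn.functional as F


def causal_softmax_reference(product):
    # product: [batch, num_heads, seq_len, seq_len] attention scores
    seq_len = product.shape[-1]
    # Keep the lower triangle so each token attends to itself and the past.
    causal = paddle.tril(paddle.ones([seq_len, seq_len], dtype=product.dtype))
    masked = product * causal - 1e4 * (1.0 - causal)
    return F.softmax(masked)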
def forward(self, query, key=None, value=None, attn_mask=None, cache=None):
    key = query if key is None else key
    value = query if value is None else value
    q = self.q_proj(query)
    k = self.k_proj(key)
    v = self.v_proj(value)

    q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim])
    q = tensor.transpose(x=q, perm=[0, 2, 1, 3])
    k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim])
    k = tensor.transpose(x=k, perm=[0, 2, 1, 3])
    v = tensor.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim])
    v = tensor.transpose(x=v, perm=[0, 2, 1, 3])

    sinusoidal_pos = self.positional_embedding(query)
    if self.rotary_value:
        q, k, v = self.apply_rotary_position_embeddings(
            sinusoidal_pos, q, k, v)
    else:
        q, k = self.apply_rotary_position_embeddings(sinusoidal_pos, q, k)

    product = tensor.matmul(x=q, y=k, transpose_y=True) * self.scale
    if attn_mask is not None:
        attn_mask = _convert_attention_mask(attn_mask, product.dtype)
        product = product + attn_mask

    weights = F.softmax(product)
    weights = self.dropout(weights)
    out = tensor.matmul(weights, v)

    # combine heads
    out = tensor.transpose(out, perm=[0, 2, 1, 3])
    out = tensor.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])
    # project to output
    out = self.out_proj(out)

    outs = [out]
    if self.need_weights:
        outs.append(weights)
    if cache is not None:
        outs.append(cache)
    return out if len(outs) == 1 else tuple(outs)
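# Hedged illustration of what apply_rotary_position_embeddings does above. This
# is a generic RoFormer-style rotation written for this note only; the actual
# method and the layout of `sinusoidal_pos` in the original class are assumptions.
import paddle


def rotate_with_sinusoidal_pos(x, sinusoidal_pos):
    # x: [bs, num_heads, seq_len, head_dim]
    # sinusoidal_pos: [1, 1, seq_len, head_dim], sin in the first half of the
    # last axis and cos in the second half (assumed layout).
    sin, cos = paddle.chunk(sinusoidal_pos, 2, axis=-1)
    # Duplicate sin/cos so they line up with the (even, odd) feature pairs of x.
    sin_pos = paddle.stack([sin, sin], axis=-1).reshape(sinusoidal_pos.shape)
    cos_pos = paddle.stack([cos, cos], axis=-1).reshape(sinusoidal_pos.shape)
    # Build (-x2, x1, -x4, x3, ...) and combine: x * cos + rotate_half(x) * sin.
    rotate_half = paddle.stack([-x[..., 1::2], x[..., 0::2]],
                               axis=-1).reshape(x.shape)
    return x * cos_pos + rotate_half * sin_pos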
def compute_kv(self, key, value):
    r"""
    Applies linear projection on input keys and values, then splits heads
    (reshape and transpose) to get keys and values from different
    representation subspaces. The results are used as key-value pairs for
    subsequent multiple parallel attention.

    It is part of calculations in multi-head attention, and is provided as
    a method to pre-compute and prefetch these results, thus we can use them
    to construct cache for inference.
    """
    k = self.k_proj(key)
    v = self.v_proj(value)
    k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim])
    k = tensor.transpose(x=k, perm=[0, 2, 1, 3])
    v = tensor.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim])
    v = tensor.transpose(x=v, perm=[0, 2, 1, 3])
    return k, v
def attention_forward(self,
                      query,
                      key=None,
                      value=None,
                      attn_mask=None,
                      cache=None):
    """
    Redefines the `forward` function of `paddle.nn.MultiHeadAttention`.
    """
    key = query if key is None else key
    value = query if value is None else value
    # Computes q, k, v
    if cache is None:
        q, k, v = self._prepare_qkv(query, key, value, cache)
    else:
        q, k, v, cache = self._prepare_qkv(query, key, value, cache)

    # Scale dot product attention
    product = tensor.matmul(x=q, y=k, transpose_y=True)
    product /= math.sqrt(self.head_dim)

    if attn_mask is not None:
        # Support bool or int mask
        attn_mask = _convert_attention_mask(attn_mask, product.dtype)
        product = product + attn_mask

    self.attention_matrix = product if self.return_attentions else None
    weights = F.softmax(product)
    if self.dropout:
        weights = F.dropout(weights,
                            self.dropout,
                            training=self.training,
                            mode="upscale_in_train")

    out = tensor.matmul(weights, v)
    if self.return_qkv:
        self.q = q
        self.k = k
        self.v = v

    # Combine heads
    out = tensor.transpose(out, perm=[0, 2, 1, 3])
    out = tensor.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])
    # Project to output
    out = self.out_proj(out)

    outs = [out]
    if self.need_weights:
        outs.append(weights)
    if cache is not None:
        outs.append(cache)
    return out if len(outs) == 1 else tuple(outs)
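# Hedged usage sketch (editor illustration): attention_forward is written to
# swap in for paddle.nn.MultiHeadAttention.forward so the intermediate Q/K/V
# tensors and the attention matrix can be harvested, e.g. for MiniLM-style
# distillation. The flag wiring below is an assumption, not the project's
# actual setup code; attention_forward reads return_qkv / return_attentions.
import paddle.nn as nn

# Replace the class method so every attention layer records what it computes.
nn.MultiHeadAttention.forward = attention_forward

encoder_layer = nn.TransformerEncoderLayer(d_model=768,
                                           nhead=12,
                                           dim_feedforward=3072)
for sublayer in encoder_layer.sublayers():
    if isinstance(sublayer, nn.MultiHeadAttention):
        sublayer.return_qkv = True
        sublayer.return_attentions = True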
def _prepare_qkv(self, query, key, value, use_cache=False, cache=None):
    """
    Prepares linear projected queries, keys and values for usage of
    subsequent multiple parallel attention. If `cache` is not None, using
    cached results to reduce redundant calculations.
    """
    q = self.q_proj(query)
    if _global_parallel_strategy == "mp":
        auto.shard_tensor(self.q_proj.weight,
                          dist_attr={
                              "process_mesh": _global_process_mesh,
                              "dims_mapping": [-1, 0]
                          })
    elif _global_parallel_strategy == "dp_mp":
        auto.shard_tensor(self.q_proj.weight,
                          dist_attr={
                              "process_mesh": _global_process_mesh,
                              "dims_mapping": [-1, 1]
                          })
    elif _global_parallel_strategy == "mp_pp":
        auto.shard_tensor(self.q_proj.weight,
                          dist_attr={
                              "process_mesh": MPPP_MESH_LIST[self.mesh_idx],
                              "dims_mapping": [-1, 0]
                          })
    elif _global_parallel_strategy == "dp_mp_pp":
        auto.shard_tensor(self.q_proj.weight,
                          dist_attr={
                              "process_mesh": DPMPPP_MESH_LIST[self.mesh_idx],
                              "dims_mapping": [-1, 1]
                          })
    q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim])
    q = tensor.transpose(x=q, perm=[0, 2, 1, 3])

    if isinstance(cache, self.StaticCache):
        # for encoder-decoder attention in inference and has cached
        k, v = cache.k, cache.v
    else:
        k, v = self.compute_kv(key, value)

    if isinstance(cache, self.Cache):
        # for decoder self-attention in inference
        k = tensor.concat([cache.k, k], axis=2)
        v = tensor.concat([cache.v, v], axis=2)
    if use_cache is True:
        cache = self.Cache(k, v)

    return (q, k, v) if use_cache is False else (q, k, v, cache)
def GetBaselineOut(self):
    paddle.disable_static(place=paddle.CUDAPlace(0))
    tensor_query = paddle.to_tensor(self.query, stop_gradient=False)

    cache_kvs = []
    cache_kv = None
    if self.has_cache_kv:
        cache_kv = paddle.to_tensor(self.cache_kv, stop_gradient=False)

    if self.has_attn_mask:
        attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=False)
    else:
        attn_mask = None

    for i in range(self.layers):
        residual = tensor_query
        ln1_out = tensor_query
        if self.pre_layer_norm:
            ln1_out = self.norm(tensor_query)

        q = self.q_proj(ln1_out)
        q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim])
        q_out = tensor.transpose(x=q, perm=[0, 2, 1, 3])
        k = self.k_proj(ln1_out)
        v = self.v_proj(ln1_out)
        k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim])
        k_out = tensor.transpose(x=k, perm=[0, 2, 1, 3])
        v = tensor.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim])
        v_out = tensor.transpose(x=v, perm=[0, 2, 1, 3])

        if self.has_cache_kv:
            # [1, B, n_head, cache_seq_len, head_dim]
            cache_k, cache_v = paddle.split(cache_kv, 2)
            cache_k = paddle.squeeze(cache_k, axis=0)
            cache_v = paddle.squeeze(cache_v, axis=0)
            # [B, n_head, cache_seq_len + seq_len, head_dim]
            # out_seq_len = cache_seq_len + seq_len
            if self.debug:
                print('q out is')
                print(q_out[0, 0, :, :])
                print('cache k out seq=128')
                print(k_out[0, 0, :, :])
            if self.gen_cache_kv:
                cache_kvs.append((k_out, v_out))
            else:
                k_out = paddle.concat([cache_k, k_out], axis=-2)
                v_out = paddle.concat([cache_v, v_out], axis=-2)

        # [B, n_head, seq_len, head_dim] * [B, n_head, out_seq_len, head_dim]
        # --> [B, n_head, seq_len, out_seq_len]
        qk_out = layers.matmul(x=q_out,
                               y=k_out,
                               transpose_y=True,
                               alpha=self.head_dim**-0.5)

        if self.debug:
            print('qk out is')
            print(qk_out[0][0][0])

        if attn_mask is not None:
            attn_mask = _convert_attention_mask(attn_mask, qk_out.dtype)
            attn_mask_out = qk_out + attn_mask
            if self.debug:
                print('attn mask out is')
                print(attn_mask_out[0][0][0])
            softmax_out = F.softmax(attn_mask_out)
        else:
            softmax_out = F.softmax(qk_out)

        if self.debug:
            print('softmax out is')
            print(softmax_out[0][0][0])
        if self.dropout_prob:
            dropout_out = F.dropout(softmax_out,
                                    self.dropout_prob,
                                    training=self.training,
                                    mode="upscale_in_train")
            # [B, n_head, seq_len, out_seq_len] * [B, n_head, out_seq_len, head_dim]
            # --> [B, n_head, seq_len, head_dim]
            qktv_out = tensor.matmul(dropout_out, v_out)
        else:
            qktv_out = tensor.matmul(softmax_out, v_out)

        fmha_out = tensor.transpose(qktv_out, perm=[0, 2, 1, 3])
        if self.debug:
            print('fmha out is')
            print(fmha_out[0][0][0])
        out_linear_in = tensor.reshape(
            x=fmha_out,
            shape=[0, 0, fmha_out.shape[2] * fmha_out.shape[3]])
        out = self.out_proj(out_linear_in)

        residual_out = residual + self.dropout(out)
        if not self.pre_layer_norm:
            attn_out = self.norm(residual_out)
        else:
            attn_out = residual_out

        ffn_ln_out = attn_out
        if self.pre_layer_norm:
            ffn_ln_out = self.ffn_norm(attn_out)

        ffn1_out = self.ffn1_proj(ffn_ln_out)
        ffn1_out = self.dropout(self.activation(ffn1_out))
        ffn2_out = self.ffn2_proj(ffn1_out)

        residual_out = attn_out + self.dropout(ffn2_out)
        final_out = residual_out
        if not self.pre_layer_norm:
            final_out = self.ffn_norm(residual_out)

        tensor_query = final_out

    if self.has_cache_kv and self.gen_cache_kv:
        return final_out, cache_kvs
    return final_out
def calc_multi_relation_loss(loss_fct,
                             s,
                             t,
                             attn_mask,
                             num_relation_heads=0,
                             alpha=0.0,
                             beta=0.0):
    """
    Calculates loss for multiple Q-Q, K-K and V-V relations. It supports the
    head-head relation, the sample-sample relation and the original
    token-token relation. The final loss value is balanced by the weights
    `alpha` and `beta`.

    Args:
        loss_fct (callable):
            Loss function for distillation. Only kl_div loss is supported now.
        s (Tensor):
            Q, K or V of the student.
        t (Tensor):
            Q, K or V of the teacher.
        attn_mask (Tensor):
            Attention mask for the relation.
        num_relation_heads (int):
            The number of relation heads. 0 means `num_relation_heads` equals
            the original number of heads. Defaults to 0.
        alpha (float):
            The weight for the head-head relation. Defaults to 0.0.
        beta (float):
            The weight for the sample-sample relation. Defaults to 0.0.

    Returns:
        Tensor: Weighted sum of the token-token loss, the head-head loss and
            the sample-sample loss.
    """
    # Initialize head_num
    if num_relation_heads > 0 and num_relation_heads != s.shape[1]:
        # s' shape: [bs, seq_len, head_num, head_dim]
        s = tensor.transpose(x=s, perm=[0, 2, 1, 3])
        # s' shape: [bs, seq_len, num_relation_heads, head_dim_new]
        s = tensor.reshape(x=s, shape=[0, 0, num_relation_heads, -1])
        # s1's shape: [bs, num_relation_heads, seq_len, head_dim_new]
        s1 = tensor.transpose(x=s, perm=[0, 2, 1, 3])
    else:
        # Keep the original head layout when no head remapping is needed.
        s1 = s
    if num_relation_heads > 0 and num_relation_heads != t.shape[1]:
        t = tensor.transpose(x=t, perm=[0, 2, 1, 3])
        t = tensor.reshape(x=t, shape=[0, 0, num_relation_heads, -1])
        t1 = tensor.transpose(x=t, perm=[0, 2, 1, 3])
    else:
        t1 = t

    s_head_dim, t_head_dim = s.shape[3], t.shape[3]

    if alpha + beta == 1.0:
        loss_token_token = 0.0
    else:
        scaled_dot_product_s1 = tensor.matmul(
            x=s1, y=s1, transpose_y=True) / math.sqrt(s_head_dim)
        del s1
        scaled_dot_product_s1 += attn_mask
        scaled_dot_product_t1 = tensor.matmul(
            x=t1, y=t1, transpose_y=True) / math.sqrt(t_head_dim)
        del t1
        scaled_dot_product_t1 += attn_mask
        loss_token_token = loss_fct(F.log_softmax(scaled_dot_product_s1),
                                    F.softmax(scaled_dot_product_t1))

    if alpha == 0.0:
        loss_head_head = 0.0
    else:
        scaled_dot_product_s = tensor.matmul(
            x=s, y=s, transpose_y=True) / math.sqrt(s_head_dim)
        attn_mask_head_head = tensor.transpose(x=attn_mask, perm=[0, 3, 1, 2])
        scaled_dot_product_s += attn_mask_head_head
        scaled_dot_product_t = tensor.matmul(
            x=t, y=t, transpose_y=True) / math.sqrt(t_head_dim)
        scaled_dot_product_t += attn_mask_head_head
        loss_head_head = loss_fct(F.log_softmax(scaled_dot_product_s),
                                  F.softmax(scaled_dot_product_t))

    if beta == 0.0:
        loss_sample_sample = 0.0
    else:
        s2 = tensor.transpose(x=s, perm=[1, 2, 0, 3])
        scaled_dot_product_s2 = tensor.matmul(
            x=s2, y=s2, transpose_y=True) / math.sqrt(s_head_dim)
        del s, s2
        # Shape: [seq_len, 1, batch_size, 1]
        attn_mask_sample_sample = tensor.transpose(x=attn_mask,
                                                   perm=[3, 1, 0, 2])
        # Shape: [seq_len, head_num, batch_size, batch_size]
        scaled_dot_product_s2 += attn_mask_sample_sample
        t2 = tensor.transpose(x=t, perm=[1, 2, 0, 3])
        scaled_dot_product_t2 = tensor.matmul(
            x=t2, y=t2, transpose_y=True) / math.sqrt(t_head_dim)
        del t, t2
        scaled_dot_product_t2 += attn_mask_sample_sample
        loss_sample_sample = loss_fct(F.log_softmax(scaled_dot_product_s2),
                                      F.softmax(scaled_dot_product_t2))

    return (1 - alpha - beta) * loss_token_token + \
        alpha * loss_head_head + beta * loss_sample_sample
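# Worked note on the weighting above (illustrative numbers, not from the source):
# with alpha=0.1 and beta=0.1 the three relation losses combine as
#   0.8 * loss_token_token + 0.1 * loss_head_head + 0.1 * loss_sample_sample,
# and the `alpha + beta == 1.0` shortcut simply skips the token-token term
# because its weight (1 - alpha - beta) would be zero.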
def forward(self, input_ids, position_ids):
    if _global_parallel_strategy == "dp":
        auto.shard_tensor(input_ids,
                          dist_attr={
                              "process_mesh": _global_process_mesh,
                              "dims_mapping": [0, -1]
                          })
    elif _global_parallel_strategy == "dp_mp":
        auto.shard_tensor(input_ids,
                          dist_attr={
                              "process_mesh": _global_process_mesh,
                              "dims_mapping": [0, -1]
                          })

    input_embeddings = self.word_embeddings(input_ids)
    position_embeddings = self.position_embeddings(position_ids)

    if _global_parallel_strategy == "mp":
        auto.shard_tensor(self.word_embeddings.weight,
                          dist_attr={
                              "process_mesh": _global_process_mesh,
                              "dims_mapping": [0, -1]
                          })
    elif _global_parallel_strategy == "dp_mp":
        auto.shard_tensor(self.word_embeddings.weight,
                          dist_attr={
                              "process_mesh": _global_process_mesh,
                              "dims_mapping": [1, -1]
                          })

    embeddings = input_embeddings + position_embeddings
    embeddings = self.dropout1(embeddings)

    # Pre-norm
    target = self.norm1(embeddings)

    # The following is the attention part
    q = self.q_proj(target)
    q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim])
    q = tensor.transpose(x=q, perm=[0, 2, 1, 3])

    k = self.k_proj(target)
    v = self.v_proj(target)

    if _global_parallel_strategy == "mp":
        auto.shard_tensor(self.q_proj.weight,
                          dist_attr={
                              "process_mesh": _global_process_mesh,
                              "dims_mapping": [-1, 0]
                          })
        auto.shard_tensor(self.k_proj.weight,
                          dist_attr={
                              "process_mesh": _global_process_mesh,
                              "dims_mapping": [-1, 0]
                          })
        auto.shard_tensor(self.v_proj.weight,
                          dist_attr={
                              "process_mesh": _global_process_mesh,
                              "dims_mapping": [-1, 0]
                          })
    elif _global_parallel_strategy == "dp_mp":
        auto.shard_tensor(self.q_proj.weight,
                          dist_attr={
                              "process_mesh": _global_process_mesh,
                              "dims_mapping": [-1, 1]
                          })
        auto.shard_tensor(self.k_proj.weight,
                          dist_attr={
                              "process_mesh": _global_process_mesh,
                              "dims_mapping": [-1, 1]
                          })
        auto.shard_tensor(self.v_proj.weight,
                          dist_attr={
                              "process_mesh": _global_process_mesh,
                              "dims_mapping": [-1, 1]
                          })

    k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim])
    k = tensor.transpose(x=k, perm=[0, 2, 1, 3])
    v = tensor.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim])
    v = tensor.transpose(x=v, perm=[0, 2, 1, 3])

    # scale dot product attention
    product = layers.matmul(x=q,
                            y=k,
                            transpose_y=True,
                            alpha=self.head_dim**-0.5)

    if self.attn_mask is not None:
        product = product + self.attn_mask

    weights = F.softmax(product)

    if self.dropout_ratio:
        weights = F.dropout(weights,
                            self.dropout_ratio,
                            training=self.training,
                            mode="upscale_in_train")

    out = tensor.matmul(weights, v)

    # combine heads
    out = tensor.transpose(out, perm=[0, 2, 1, 3])
    out = tensor.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])

    # project to output
    out = self.out_proj(out)

    if _global_parallel_strategy == "mp":
        auto.shard_tensor(self.out_proj.weight,
                          dist_attr={
                              "process_mesh": _global_process_mesh,
                              "dims_mapping": [0, -1]
                          })
    elif _global_parallel_strategy == "dp_mp":
        auto.shard_tensor(self.out_proj.weight,
                          dist_attr={
                              "process_mesh": _global_process_mesh,
                              "dims_mapping": [1, -1]
                          })

    # Add residual
    residual = embeddings + self.dropout2(out)

    # Pre-norm
    out0 = self.norm2(residual)

    # The following is the MLP part
    out1 = self.linear0(out0)
    out2 = F.gelu(out1, approximate=True)
    out3 = self.linear1(out2)

    if _global_parallel_strategy == "mp":
        auto.shard_tensor(self.linear0.weight,
                          dist_attr={
                              "process_mesh": _global_process_mesh,
                              "dims_mapping": [-1, 0]
                          })
        auto.shard_tensor(self.linear1.weight,
                          dist_attr={
                              "process_mesh": _global_process_mesh,
                              "dims_mapping": [0, -1]
                          })
    elif _global_parallel_strategy == "dp_mp":
        auto.shard_tensor(self.linear0.weight,
                          dist_attr={
                              "process_mesh": _global_process_mesh,
                              "dims_mapping": [-1, 1]
                          })
        auto.shard_tensor(self.linear1.weight,
                          dist_attr={
                              "process_mesh": _global_process_mesh,
                              "dims_mapping": [1, -1]
                          })

    # Add residual
    final = residual + self.dropout3(out3)
    return final
def forward(self, input):
    if _global_parallel_strategy == "dp":
        auto.shard_tensor(input,
                          dist_attr={
                              "process_mesh": _global_process_mesh,
                              "dims_mapping": [0, -1, -1]
                          })
    elif _global_parallel_strategy == "dp_mp":
        auto.shard_tensor(input,
                          dist_attr={
                              "process_mesh": _global_process_mesh,
                              "dims_mapping": [0, -1, -1]
                          })

    q = self.q_proj(input)
    q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim])
    q = tensor.transpose(x=q, perm=[0, 2, 1, 3])

    k = self.k_proj(input)
    v = self.v_proj(input)

    if _global_parallel_strategy == "mp":
        auto.shard_tensor(self.q_proj.weight,
                          dist_attr={
                              "process_mesh": _global_process_mesh,
                              "dims_mapping": [-1, 0]
                          })
        auto.shard_tensor(self.k_proj.weight,
                          dist_attr={
                              "process_mesh": _global_process_mesh,
                              "dims_mapping": [-1, 0]
                          })
        auto.shard_tensor(self.v_proj.weight,
                          dist_attr={
                              "process_mesh": _global_process_mesh,
                              "dims_mapping": [-1, 0]
                          })
    elif _global_parallel_strategy == "dp_mp":
        auto.shard_tensor(self.q_proj.weight,
                          dist_attr={
                              "process_mesh": _global_process_mesh,
                              "dims_mapping": [-1, 1]
                          })
        auto.shard_tensor(self.k_proj.weight,
                          dist_attr={
                              "process_mesh": _global_process_mesh,
                              "dims_mapping": [-1, 1]
                          })
        auto.shard_tensor(self.v_proj.weight,
                          dist_attr={
                              "process_mesh": _global_process_mesh,
                              "dims_mapping": [-1, 1]
                          })

    k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim])
    k = tensor.transpose(x=k, perm=[0, 2, 1, 3])
    v = tensor.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim])
    v = tensor.transpose(x=v, perm=[0, 2, 1, 3])

    # scale dot product attention
    product = layers.matmul(x=q,
                            y=k,
                            transpose_y=True,
                            alpha=self.head_dim**-0.5)

    if self.attn_mask is not None:
        product = product + self.attn_mask

    weights = F.softmax(product)

    if self.dropout_ratio:
        weights = F.dropout(weights,
                            self.dropout_ratio,
                            training=self.training,
                            mode="upscale_in_train")

    out = tensor.matmul(weights, v)

    # combine heads
    out = tensor.transpose(out, perm=[0, 2, 1, 3])
    out = tensor.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])

    # project to output
    out = self.out_proj(out)

    if _global_parallel_strategy == "mp":
        auto.shard_tensor(self.out_proj.weight,
                          dist_attr={
                              "process_mesh": _global_process_mesh,
                              "dims_mapping": [0, -1]
                          })
    elif _global_parallel_strategy == "dp_mp":
        auto.shard_tensor(self.out_proj.weight,
                          dist_attr={
                              "process_mesh": _global_process_mesh,
                              "dims_mapping": [1, -1]
                          })

    return out
def forward(self,
            query,
            key,
            value,
            attn_mask=None,
            use_cache=False,
            cache=None):
    """
    Applies multi-head attention to map queries and a set of key-value pairs
    to outputs.
    """
    key = query if key is None else key
    value = query if value is None else value
    # compute q, k, v
    if use_cache is False:
        if self.fuse:
            q, k, v = self._fuse_prepare_qkv(query)
        else:
            q, k, v = self._prepare_qkv(query, key, value, use_cache, cache)
    else:
        q, k, v, cache = self._prepare_qkv(query, key, value, use_cache, cache)

    product = layers.matmul(x=q,
                            y=k,
                            transpose_y=True,
                            alpha=self.head_dim**-0.5)
    if attn_mask is not None:
        product = product + attn_mask

    weights = F.softmax(product)
    if self.dropout:
        weights = F.dropout(weights,
                            self.dropout,
                            training=self.training,
                            mode="upscale_in_train")

    out = tensor.matmul(weights, v)

    # combine heads
    out = tensor.transpose(out, perm=[0, 2, 1, 3])
    out = tensor.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])

    # project to output
    out = self.out_proj(out)
    if _global_parallel_strategy == "mp":
        auto.shard_tensor(self.out_proj.weight,
                          dist_attr={
                              "process_mesh": _global_process_mesh,
                              "dims_mapping": [0, -1]
                          })
    elif _global_parallel_strategy == "dp_mp":
        auto.shard_tensor(self.out_proj.weight,
                          dist_attr={
                              "process_mesh": _global_process_mesh,
                              "dims_mapping": [1, -1]
                          })
    elif _global_parallel_strategy == "mp_pp":
        auto.shard_tensor(self.out_proj.weight,
                          dist_attr={
                              "process_mesh": MPPP_MESH_LIST[self.mesh_idx],
                              "dims_mapping": [0, -1]
                          })
    elif _global_parallel_strategy == "dp_mp_pp":
        auto.shard_tensor(self.out_proj.weight,
                          dist_attr={
                              "process_mesh": DPMPPP_MESH_LIST[self.mesh_idx],
                              "dims_mapping": [1, -1]
                          })

    outs = [out]
    if self.need_weights:
        outs.append(weights)
    if use_cache:
        outs.append(cache)
    return out if len(outs) == 1 else tuple(outs)
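# Design note (hedged reading of the dims_mapping values above, added by the
# editor): the q/k/v projection weights are split along their output dimension
# ([-1, 0] or [-1, 1], i.e. column parallel), so each model-parallel rank holds
# a subset of attention heads, while out_proj is split along its input dimension
# ([0, -1] or [1, -1], i.e. row parallel). This is the usual Megatron-style
# pairing: the attention computation stays local to a rank and communication is
# deferred to the output projection.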