Example No. 1
def calc_minilm_loss(loss_fct, s, t, attn_mask, num_relation_heads=0):
    # Regroup heads when num_relation_heads differs from the original head count.
    if num_relation_heads > 0 and num_relation_heads != s.shape[1]:
        # s's shape: [bs, seq_len, head_num, head_dim]
        s = tensor.transpose(x=s, perm=[0, 2, 1, 3])
        # s's shape: [bs, seq_len, num_relation_heads, head_dim_new]
        s = tensor.reshape(x=s, shape=[0, 0, num_relation_heads, -1])
        # s's shape: [bs, num_relation_heads, seq_len, head_dim_new]
        s = tensor.transpose(x=s, perm=[0, 2, 1, 3])
    if num_relation_heads > 0 and num_relation_heads != t.shape[1]:
        t = tensor.transpose(x=t, perm=[0, 2, 1, 3])
        t = tensor.reshape(x=t, shape=[0, 0, num_relation_heads, -1])
        t = tensor.transpose(x=t, perm=[0, 2, 1, 3])

    pad_seq_len = s.shape[2]
    s_head_dim, t_head_dim = s.shape[3], t.shape[3]
    scaled_dot_product_s = tensor.matmul(
        x=s, y=s, transpose_y=True) / math.sqrt(s_head_dim)
    del s
    scaled_dot_product_s += attn_mask

    scaled_dot_product_t = tensor.matmul(
        x=t, y=t, transpose_y=True) / math.sqrt(t_head_dim)
    del t
    scaled_dot_product_t += attn_mask
    loss = loss_fct(F.log_softmax(scaled_dot_product_s),
                    F.softmax(scaled_dot_product_t))
    return loss
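
A quick shape check of the head regrouping above, under assumed sizes (12 heads of dim 64 regrouped into 24 relation heads of dim 32); this is an illustrative sketch, not part of the original example:

import paddle

bs, num_heads, seq_len, head_dim = 2, 12, 16, 64
s = paddle.randn([bs, num_heads, seq_len, head_dim])
s = paddle.transpose(s, perm=[0, 2, 1, 3])    # [2, 16, 12, 64]
s = paddle.reshape(s, [0, 0, 24, -1])         # [2, 16, 24, 32]
s = paddle.transpose(s, perm=[0, 2, 1, 3])    # [2, 24, 16, 32]
print(s.shape)  # [2, 24, 16, 32]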
Example No. 2
    def GetBaselineOut(self):
        paddle.disable_static(place=paddle.CUDAPlace(0))
        tensor_query = paddle.to_tensor(self.query, stop_gradient=False)
        if self.has_attn_mask:
            attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=False)
        else:
            attn_mask = None
        residual = tensor_query

        ln1_out = tensor_query
        if self.pre_layer_norm:
            ln1_out = self.norm1(tensor_query)

        q = self.q_proj(ln1_out)
        q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim])
        q_out = tensor.transpose(x=q, perm=[0, 2, 1, 3])
        k = self.k_proj(ln1_out)
        v = self.v_proj(ln1_out)
        k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim])
        k_out = tensor.transpose(x=k, perm=[0, 2, 1, 3])
        v = tensor.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim])
        v_out = tensor.transpose(x=v, perm=[0, 2, 1, 3])

        qk_out = layers.matmul(x=q_out,
                               y=k_out,
                               transpose_y=True,
                               alpha=self.head_dim**-0.5)

        if attn_mask is not None:
            attn_mask = _convert_attention_mask(attn_mask, qk_out.dtype)
            attn_mask_out = qk_out + attn_mask
            softmax_out = F.softmax(attn_mask_out)
        else:
            softmax_out = F.softmax(qk_out)

        if self.dropout_prob:
            dropout_out = F.dropout(softmax_out,
                                    self.dropout_prob,
                                    training=self.training,
                                    mode="upscale_in_train")
            qktv_out = tensor.matmul(dropout_out, v_out)
        else:
            qktv_out = tensor.matmul(softmax_out, v_out)

        fmha_out = tensor.transpose(qktv_out, perm=[0, 2, 1, 3])
        out_linear_in = tensor.reshape(
            x=fmha_out, shape=[0, 0, fmha_out.shape[2] * fmha_out.shape[3]])
        out = self.out_proj(out_linear_in)

        residual_out = residual + self.dropout(out)
        if not self.pre_layer_norm:
            final_out = self.norm1(residual_out)
        else:
            final_out = residual_out
        paddle.autograd.backward([final_out], [paddle.to_tensor(self.dout)],
                                 retain_graph=True)
        return final_out, tensor_query.grad
Example No. 3
    def compute_kv(self, key, value):
        r"""
        Applies linear projection on input keys and values, then splits heads
        (reshape and transpose) to get keys and values from different representation
        subspaces. The results are used as key-value pairs for subsequent multiple
        parallel attention.
        It is part of the calculations in multi-head attention, and is provided as
        a method to pre-compute and prefetch these results so that they can be used
        to construct the cache for inference.
        """
        k = self.k_proj(key)

        if _global_parallel_strategy == "mp":
            auto.shard_tensor(
                self.k_proj.weight,
                dist_attr={
                    "process_mesh": _global_process_mesh,
                    "dims_mapping": [-1, 0]
                })
        elif _global_parallel_strategy == "dp_mp":
            auto.shard_tensor(
                self.k_proj.weight,
                dist_attr={
                    "process_mesh": _global_process_mesh,
                    "dims_mapping": [-1, 1]
                })

        v = self.v_proj(value)

        if _global_parallel_strategy == "mp":
            auto.shard_tensor(
                self.v_proj.weight,
                dist_attr={
                    "process_mesh": _global_process_mesh,
                    "dims_mapping": [-1, 0]
                })
        elif _global_parallel_strategy == "dp_mp":
            auto.shard_tensor(
                self.v_proj.weight,
                dist_attr={
                    "process_mesh": _global_process_mesh,
                    "dims_mapping": [-1, 1]
                })

        k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim])
        k = tensor.transpose(x=k, perm=[0, 2, 1, 3])
        v = tensor.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim])
        v = tensor.transpose(x=v, perm=[0, 2, 1, 3])
        return k, v
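
A minimal sketch (assumed sizes, not the Paddle auto-parallel API itself) of what the dims_mapping values above express: -1 leaves a weight axis replicated, while an index i splits that axis across mesh dimension i, so [-1, 0] is a column-parallel split of the projection weight under "mp":

# Illustrative shard-shape arithmetic only; names and sizes are assumptions.
hidden_size = 1024
mesh_shape = [2]           # e.g. 2-way model parallelism
dims_mapping = [-1, 0]     # per weight axis: -1 = replicated, i = split over mesh dim i

global_shape = [hidden_size, hidden_size]   # k_proj / v_proj weight
local_shape = [
    dim if mapping == -1 else dim // mesh_shape[mapping]
    for dim, mapping in zip(global_shape, dims_mapping)
]
print(local_shape)  # [1024, 512] on each model-parallel rank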
Example No. 4
    def _prepare_qkv(self, query, key, value, use_cache=False, cache=None):
        r"""
        Prepares linearly projected queries, keys and values for use in the
        subsequent multiple parallel attention. If `cache` is not None, cached
        results are used to reduce redundant calculations.

        """
        q = self.q_proj(query)
        q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim])
        q = tensor.transpose(x=q, perm=[0, 2, 1, 3])

        if isinstance(cache, self.StaticCache):
            # for encoder-decoder attention in inference, with cached key/value
            k, v = cache.k, cache.v
        else:
            k, v = self.compute_kv(key, value)

        if isinstance(cache, self.Cache):
            # for decoder self-attention in inference
            k = tensor.concat([cache.k, k], axis=2)
            v = tensor.concat([cache.v, v], axis=2)
        if use_cache is True:
            cache = self.Cache(k, v)

        return (q, k, v) if use_cache is False else (q, k, v, cache)
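
A small, self-contained sketch (assumed shapes, independent of the class above) of the cache growth that tensor.concat([cache.k, k], axis=2) performs during incremental decoding: cached keys/values have shape [bs, num_heads, cached_len, head_dim], and each new step appends its positions along axis 2.

import paddle

bs, num_heads, head_dim = 2, 8, 64
cache_k = paddle.zeros([bs, num_heads, 5, head_dim])   # 5 already-decoded positions
new_k = paddle.zeros([bs, num_heads, 1, head_dim])     # current decoding step
k = paddle.concat([cache_k, new_k], axis=2)
print(k.shape)  # [2, 8, 6, 64]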
Example No. 5
def calc_minilm_loss(loss_fct, s, t, attn_mask, num_relation_heads=0):
    """
    Calculates the loss for the Q-Q, K-K and V-V relations from MiniLMv2.
    Args:
        loss_fct (callable):
            Loss function for distillation. It only supports kl_div loss now.
        s (Tensor):
            Q, K or V of the student.
        t (Tensor):
            Q, K or V of the teacher.
        attn_mask (Tensor):
            Attention mask for the relation.
        num_relation_heads (int):
            The number of relation heads. 0 means `num_relation_heads` equals
            the original number of heads.
            Defaults to 0.

    Returns:
        Tensor: MiniLM loss value.

    """
    # Regroup heads when num_relation_heads differs from the original head count.
    if num_relation_heads > 0 and num_relation_heads != s.shape[1]:
        # s's shape: [bs, seq_len, head_num, head_dim]
        s = tensor.transpose(x=s, perm=[0, 2, 1, 3])
        # s's shape: [bs, seq_len, num_relation_heads, head_dim_new]
        s = tensor.reshape(x=s, shape=[0, 0, num_relation_heads, -1])
        # s's shape: [bs, num_relation_heads, seq_len, head_dim_new]
        s = tensor.transpose(x=s, perm=[0, 2, 1, 3])
    if num_relation_heads > 0 and num_relation_heads != t.shape[1]:
        t = tensor.transpose(x=t, perm=[0, 2, 1, 3])
        t = tensor.reshape(x=t, shape=[0, 0, num_relation_heads, -1])
        t = tensor.transpose(x=t, perm=[0, 2, 1, 3])

    s_head_dim, t_head_dim = s.shape[3], t.shape[3]
    scaled_dot_product_s = tensor.matmul(
        x=s, y=s, transpose_y=True) / math.sqrt(s_head_dim)
    del s
    scaled_dot_product_s += attn_mask

    scaled_dot_product_t = tensor.matmul(
        x=t, y=t, transpose_y=True) / math.sqrt(t_head_dim)
    del t
    scaled_dot_product_t += attn_mask
    loss = loss_fct(F.log_softmax(scaled_dot_product_s),
                    F.softmax(scaled_dot_product_t))
    return loss
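
A hypothetical call sketch for calc_minilm_loss; the shapes, the zero mask and the kl_div wrapper below are assumptions for illustration, not part of the original example:

from functools import partial

import paddle
import paddle.nn.functional as F

bs, num_heads, seq_len, head_dim = 2, 12, 16, 64
s = paddle.randn([bs, num_heads, seq_len, head_dim])   # student Q (or K, V)
t = paddle.randn([bs, num_heads, seq_len, head_dim])   # teacher Q (or K, V)
attn_mask = paddle.zeros([bs, 1, 1, seq_len])          # additive mask, 0 = keep

loss_fct = partial(F.kl_div, reduction="sum")
loss = calc_minilm_loss(loss_fct, s, t, attn_mask, num_relation_heads=24)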
Example No. 6
    def forward(self,
                query,
                key,
                value,
                attn_mask=None,
                use_cache=False,
                cache=None):
        r"""
        Applies multi-head attention to map queries and a set of key-value pairs
        to outputs.
        """
        key = query if key is None else key
        value = query if value is None else value
        # compute q, k, v
        if use_cache is False:
            if self.fuse:
                q, k, v = self._fuse_prepare_qkv(query)
            else:
                q, k, v = self._prepare_qkv(query, key, value, use_cache,
                                            cache)
        else:
            q, k, v, cache = self._prepare_qkv(query, key, value, use_cache,
                                               cache)
        # scaled dot-product attention
        product = layers.matmul(x=q,
                                y=k,
                                transpose_y=True,
                                alpha=self.head_dim**-0.5)

        # The fused kernel below replaces the plain masked softmax:
        #     if attn_mask is not None:
        #         product = product + attn_mask
        #     weights = F.softmax(product)
        # by applying an upper-triangular (causal) mask together with the
        # softmax in a single pass (attn_mask is not used on this path).
        weights = incubate.softmax_mask_fuse_upper_triangle(product)

        if self.dropout:
            with get_rng_state_tracker().rng_state('local_seed'):
                weights = F.dropout(weights,
                                    self.dropout,
                                    training=self.training,
                                    mode="upscale_in_train")

        out = tensor.matmul(weights, v)

        # combine heads
        out = tensor.transpose(out, perm=[0, 2, 1, 3])
        out = tensor.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])

        # project to output
        out = self.out_proj(out)

        outs = [out]
        if self.need_weights:
            outs.append(weights)
        if use_cache:
            outs.append(cache)
        return out if len(outs) == 1 else tuple(outs)
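
For reference, a non-fused sketch (assumed shapes) of what softmax_mask_fuse_upper_triangle computes: bias out the strictly upper triangle (future positions) and apply softmax over the last axis.

import paddle
import paddle.nn.functional as F

bs, num_heads, seq_len = 2, 8, 16
product = paddle.randn([bs, num_heads, seq_len, seq_len])

# -1e4 above the diagonal, 0 on and below it; broadcasts over batch and heads.
causal_bias = paddle.triu(paddle.full([seq_len, seq_len], -1e4), diagonal=1)
weights = F.softmax(product + causal_bias, axis=-1)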
Example No. 7
    def forward(self, query, key=None, value=None, attn_mask=None, cache=None):
        key = query if key is None else key
        value = query if value is None else value

        q = self.q_proj(query)
        k = self.k_proj(key)
        v = self.v_proj(value)

        q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim])
        q = tensor.transpose(x=q, perm=[0, 2, 1, 3])
        k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim])
        k = tensor.transpose(x=k, perm=[0, 2, 1, 3])
        v = tensor.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim])
        v = tensor.transpose(x=v, perm=[0, 2, 1, 3])

        sinusoidal_pos = self.positional_embedding(query)
        if self.rotary_value:
            q, k, v = self.apply_rotary_position_embeddings(
                sinusoidal_pos, q, k, v)
        else:
            q, k = self.apply_rotary_position_embeddings(sinusoidal_pos, q, k)

        product = tensor.matmul(x=q, y=k, transpose_y=True) * self.scale
        if attn_mask is not None:
            attn_mask = _convert_attention_mask(attn_mask, product.dtype)
            product = product + attn_mask

        weights = F.softmax(product)
        weights = self.dropout(weights)
        out = tensor.matmul(weights, v)

        # combine heads
        out = tensor.transpose(out, perm=[0, 2, 1, 3])
        out = tensor.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])

        # project to output
        out = self.out_proj(out)

        outs = [out]
        if self.need_weights:
            outs.append(weights)
        if cache is not None:
            outs.append(cache)
        return out if len(outs) == 1 else tuple(outs)
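
As orientation only, a generic "rotate-half" sketch of rotary position embeddings; the RoFormer-style layer above builds sinusoidal_pos itself and may pair dimensions differently (e.g. interleaved), so this is an assumption-laden illustration rather than its exact implementation:

import paddle

def rotate_half(x):
    # Split the last axis into two halves and swap them with a sign flip.
    x1, x2 = paddle.chunk(x, 2, axis=-1)
    return paddle.concat([-x2, x1], axis=-1)

def apply_rope(q, k, sin, cos):
    # q, k:     [bs, num_heads, seq_len, head_dim]
    # sin, cos: broadcastable, e.g. [1, 1, seq_len, head_dim]
    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin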
Example No. 8
    def compute_kv(self, key, value):
        r"""
        Applies linear projection on input keys and values, then splits heads
        (reshape and transpose) to get keys and values from different representation
        subspaces. The results are used as key-value pairs for subsequent multiple
        parallel attention.

        It is part of the calculations in multi-head attention, and is provided as
        a method to pre-compute and prefetch these results so that they can be used
        to construct the cache for inference.

        """
        k = self.k_proj(key)
        v = self.v_proj(value)
        k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim])
        k = tensor.transpose(x=k, perm=[0, 2, 1, 3])
        v = tensor.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim])
        v = tensor.transpose(x=v, perm=[0, 2, 1, 3])
        return k, v
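
A shape walk-through of compute_kv under assumed sizes: the linear projection keeps [bs, seq_len, embed_dim], the reshape splits embed_dim into heads, and the transpose moves the head axis forward so attention runs per head.

import paddle

bs, seq_len, num_heads, head_dim = 2, 16, 8, 64
k = paddle.randn([bs, seq_len, num_heads * head_dim])   # output of k_proj
k = paddle.reshape(k, [0, 0, num_heads, head_dim])      # [2, 16, 8, 64]
k = paddle.transpose(k, perm=[0, 2, 1, 3])              # [2, 8, 16, 64]
print(k.shape)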
Example No. 9
def attention_forward(self,
                      query,
                      key=None,
                      value=None,
                      attn_mask=None,
                      cache=None):
    """
    Redefines the `forward` function of `paddle.nn.MultiHeadAttention`.
    """
    key = query if key is None else key
    value = query if value is None else value
    # Computes q, k, v
    if cache is None:
        q, k, v = self._prepare_qkv(query, key, value, cache)
    else:
        q, k, v, cache = self._prepare_qkv(query, key, value, cache)

    # Scaled dot-product attention
    product = tensor.matmul(x=q, y=k, transpose_y=True)
    product /= math.sqrt(self.head_dim)

    if attn_mask is not None:
        # Support bool or int mask
        attn_mask = _convert_attention_mask(attn_mask, product.dtype)
        product = product + attn_mask

    self.attention_matrix = product if self.return_attentions else None
    weights = F.softmax(product)
    if self.dropout:
        weights = F.dropout(weights,
                            self.dropout,
                            training=self.training,
                            mode="upscale_in_train")

    out = tensor.matmul(weights, v)
    if self.return_qkv:
        self.q = q
        self.k = k
        self.v = v

    # Combine heads
    out = tensor.transpose(out, perm=[0, 2, 1, 3])
    out = tensor.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])

    # Project to output
    out = self.out_proj(out)

    outs = [out]
    if self.need_weights:
        outs.append(weights)
    if cache is not None:
        outs.append(cache)
    return out if len(outs) == 1 else tuple(outs)
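
Such a redefinition is normally monkey-patched onto the attention class; a hedged sketch of how that might look (the target class comes from the docstring, while the extra attribute flags are inferred from the function body):

import paddle.nn as nn

nn.MultiHeadAttention.forward = attention_forward

layer = nn.MultiHeadAttention(embed_dim=768, num_heads=12)
layer.return_attentions = True   # attributes consumed by attention_forward
layer.return_qkv = True
layer.need_weights = False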
Example No. 10
 def _prepare_qkv(self, query, key, value, use_cache=False, cache=None):
     """
      Prepares linearly projected queries, keys and values for use in the
      subsequent multiple parallel attention. If `cache` is not None, cached
      results are used to reduce redundant calculations.
     """
     q = self.q_proj(query)
     if _global_parallel_strategy == "mp":
         auto.shard_tensor(self.q_proj.weight,
                           dist_attr={
                               "process_mesh": _global_process_mesh,
                               "dims_mapping": [-1, 0]
                           })
     elif _global_parallel_strategy == "dp_mp":
         auto.shard_tensor(self.q_proj.weight,
                           dist_attr={
                               "process_mesh": _global_process_mesh,
                               "dims_mapping": [-1, 1]
                           })
     elif _global_parallel_strategy == "mp_pp":
         auto.shard_tensor(self.q_proj.weight,
                           dist_attr={
                               "process_mesh":
                               MPPP_MESH_LIST[self.mesh_idx],
                               "dims_mapping": [-1, 0]
                           })
     elif _global_parallel_strategy == "dp_mp_pp":
         auto.shard_tensor(self.q_proj.weight,
                           dist_attr={
                               "process_mesh":
                               DPMPPP_MESH_LIST[self.mesh_idx],
                               "dims_mapping": [-1, 1]
                           })
     q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim])
     q = tensor.transpose(x=q, perm=[0, 2, 1, 3])
     if isinstance(cache, self.StaticCache):
          # for encoder-decoder attention in inference, with cached key/value
         k, v = cache.k, cache.v
     else:
         k, v = self.compute_kv(key, value)
     if isinstance(cache, self.Cache):
         # for decoder self-attention in inference
         k = tensor.concat([cache.k, k], axis=2)
         v = tensor.concat([cache.v, v], axis=2)
     if use_cache is True:
         cache = self.Cache(k, v)
     return (q, k, v) if use_cache is False else (q, k, v, cache)
Example No. 11
    def GetBaselineOut(self):
        paddle.disable_static(place=paddle.CUDAPlace(0))
        tensor_query = paddle.to_tensor(self.query, stop_gradient=False)

        cache_kvs = []
        cache_kv = None
        if self.has_cache_kv:
            cache_kv = paddle.to_tensor(self.cache_kv, stop_gradient=False)

        if self.has_attn_mask:
            attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=False)
        else:
            attn_mask = None

        for i in range(self.layers):
            residual = tensor_query
            ln1_out = tensor_query
            if self.pre_layer_norm:
                ln1_out = self.norm(tensor_query)

            q = self.q_proj(ln1_out)
            q = tensor.reshape(x=q,
                               shape=[0, 0, self.num_heads, self.head_dim])
            q_out = tensor.transpose(x=q, perm=[0, 2, 1, 3])
            k = self.k_proj(ln1_out)
            v = self.v_proj(ln1_out)
            k = tensor.reshape(x=k,
                               shape=[0, 0, self.num_heads, self.head_dim])
            k_out = tensor.transpose(x=k, perm=[0, 2, 1, 3])
            v = tensor.reshape(x=v,
                               shape=[0, 0, self.num_heads, self.head_dim])
            v_out = tensor.transpose(x=v, perm=[0, 2, 1, 3])

            if self.has_cache_kv:
                # [1, B, n_head, cache_seq_len, head_dim]
                cache_k, cache_v = paddle.split(cache_kv, 2)
                cache_k = paddle.squeeze(cache_k, axis=0)
                cache_v = paddle.squeeze(cache_v, axis=0)
                # [B, n_head, cache_seq_len + seq_len, head_dim]
                # out_seq_len = cache_seq_len + seq_len
                if self.debug:
                    print('q out is')
                    print(q_out[0, 0, :, :])
                    print('cache k out seq=128')
                    print(k_out[0, 0, :, :])
                if self.gen_cache_kv:
                    cache_kvs.append((k_out, v_out))
                else:
                    k_out = paddle.concat([cache_k, k_out], axis=-2)
                    v_out = paddle.concat([cache_v, v_out], axis=-2)

            # [B, n_head, seq_len, head_dim] * [B, n_head, out_seq_len, head_dim]
            # --> [B, n_head, seq_len, out_seq_len]
            qk_out = layers.matmul(x=q_out,
                                   y=k_out,
                                   transpose_y=True,
                                   alpha=self.head_dim**-0.5)

            if self.debug:
                print('qk out is')
                print(qk_out[0][0][0])

            if attn_mask is not None:
                attn_mask = _convert_attention_mask(attn_mask, qk_out.dtype)
                attn_mask_out = qk_out + attn_mask
                if self.debug:
                    print('attn mask out is')
                    print(attn_mask_out[0][0][0])
                softmax_out = F.softmax(attn_mask_out)
            else:
                softmax_out = F.softmax(qk_out)

            if self.debug:
                print('softmax out is')
                print(softmax_out[0][0][0])
            if self.dropout_prob:
                dropout_out = F.dropout(softmax_out,
                                        self.dropout_prob,
                                        training=self.training,
                                        mode="upscale_in_train")
                # [B, n_head, seq_len, out_seq_len] * [B, n_head, out_seq_len, head_dim]
                # --> [B, n_head, seq_len, head_dim]
                qktv_out = tensor.matmul(dropout_out, v_out)
            else:
                qktv_out = tensor.matmul(softmax_out, v_out)

            fmha_out = tensor.transpose(qktv_out, perm=[0, 2, 1, 3])
            if self.debug:
                print('fmha out is')
                print(fmha_out[0][0][0])
            out_linear_in = tensor.reshape(
                x=fmha_out,
                shape=[0, 0, fmha_out.shape[2] * fmha_out.shape[3]])
            out = self.out_proj(out_linear_in)

            residual_out = residual + self.dropout(out)
            if not self.pre_layer_norm:
                attn_out = self.norm(residual_out)
            else:
                attn_out = residual_out

            ffn_ln_out = attn_out
            if self.pre_layer_norm:
                ffn_ln_out = self.ffn_norm(attn_out)

            ffn1_out = self.ffn1_proj(ffn_ln_out)
            ffn1_out = self.dropout(self.activation(ffn1_out))
            ffn2_out = self.ffn2_proj(ffn1_out)

            residual_out = attn_out + self.dropout(ffn2_out)
            final_out = residual_out
            if not self.pre_layer_norm:
                final_out = self.ffn_norm(residual_out)

            tensor_query = final_out

        if self.has_cache_kv and self.gen_cache_kv:
            return final_out, cache_kvs
        return final_out
Example No. 12
def calc_multi_relation_loss(loss_fct,
                             s,
                             t,
                             attn_mask,
                             num_relation_heads=0,
                             alpha=0.0,
                             beta=0.0):
    """
    Calculates the loss for multiple Q-Q, K-K and V-V relations. It supports the
    head-head relation, the sample-sample relation and the original token-token
    relation. The final loss value can be balanced by the weights `alpha` and `beta`.

    Args:
        loss_fct (callable):
            Loss function for distillation. It only supports kl_div loss now.
        s (Tensor):
            Q, K or V of the student.
        t (Tensor):
            Q, K or V of the teacher.
        attn_mask (Tensor):
            Attention mask for the relation.
        num_relation_heads (int):
            The number of relation heads. 0 means `num_relation_heads` equals
            the original number of heads.
            Defaults to 0.
        alpha (float):
            The weight for head-head relation.
            Defaults to 0.0.
        beta (float):
            The weight for sample-sample relation.
            Defaults to 0.0.

    Returns:
        Tensor: Weighted loss of token-token loss, head-head loss and
            sample-sample loss.

    """
    # Regroup heads when num_relation_heads differs from the original head count.
    # NOTE: s1/t1 (used by the token-token branch below) are only defined on
    # this regrouping path, so that branch assumes num_relation_heads > 0.
    if num_relation_heads > 0 and num_relation_heads != s.shape[1]:
        # s's shape: [bs, seq_len, head_num, head_dim]
        s = tensor.transpose(x=s, perm=[0, 2, 1, 3])
        # s's shape: [bs, seq_len, num_relation_heads, head_dim_new]
        s = tensor.reshape(x=s, shape=[0, 0, num_relation_heads, -1])
        # s1's shape: [bs, num_relation_heads, seq_len, head_dim_new]
        s1 = tensor.transpose(x=s, perm=[0, 2, 1, 3])
    if num_relation_heads > 0 and num_relation_heads != t.shape[1]:
        t = tensor.transpose(x=t, perm=[0, 2, 1, 3])
        t = tensor.reshape(x=t, shape=[0, 0, num_relation_heads, -1])
        t1 = tensor.transpose(x=t, perm=[0, 2, 1, 3])

    s_head_dim, t_head_dim = s.shape[3], t.shape[3]

    if alpha + beta == 1.0:
        loss_token_token = 0.0
    else:
        scaled_dot_product_s1 = tensor.matmul(
            x=s1, y=s1, transpose_y=True) / math.sqrt(s_head_dim)
        del s1
        scaled_dot_product_s1 += attn_mask
        scaled_dot_product_t1 = tensor.matmul(
            x=t1, y=t1, transpose_y=True) / math.sqrt(t_head_dim)
        del t1
        scaled_dot_product_t1 += attn_mask
        loss_token_token = loss_fct(F.log_softmax(scaled_dot_product_s1),
                                    F.softmax(scaled_dot_product_t1))

    if alpha == 0.0:
        loss_head_head = 0.0
    else:
        scaled_dot_product_s = tensor.matmul(
            x=s, y=s, transpose_y=True) / math.sqrt(s_head_dim)
        attn_mask_head_head = tensor.transpose(x=attn_mask, perm=[0, 3, 1, 2])

        scaled_dot_product_s += attn_mask_head_head
        scaled_dot_product_t = tensor.matmul(
            x=t, y=t, transpose_y=True) / math.sqrt(t_head_dim)
        scaled_dot_product_t += attn_mask_head_head
        loss_head_head = loss_fct(F.log_softmax(scaled_dot_product_s),
                                  F.softmax(scaled_dot_product_t))
    if beta == 0.0:
        loss_sample_sample = 0.0
    else:
        s2 = tensor.transpose(x=s, perm=[1, 2, 0, 3])
        scaled_dot_product_s2 = tensor.matmul(
            x=s2, y=s2, transpose_y=True) / math.sqrt(s_head_dim)

        del s, s2
        # Shape: [seq_len, 1, batch_size, 1]
        attn_mask_sample_sample = tensor.transpose(x=attn_mask,
                                                   perm=[3, 1, 0, 2])

        # Shape: [seq_len, head_num, batch_size, batch_size]
        scaled_dot_product_s2 += attn_mask_sample_sample
        t2 = tensor.transpose(x=t, perm=[1, 2, 0, 3])
        scaled_dot_product_t2 = tensor.matmul(
            x=t2, y=t2, transpose_y=True) / math.sqrt(t_head_dim)

        del t, t2
        scaled_dot_product_t2 += attn_mask_sample_sample
        loss_sample_sample = loss_fct(F.log_softmax(scaled_dot_product_s2),
                                      F.softmax(scaled_dot_product_t2))

    return (
        1 - alpha - beta
    ) * loss_token_token + alpha * loss_head_head + beta * loss_sample_sample
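
A hypothetical call sketch for calc_multi_relation_loss; the shapes, the zero mask, the kl_div wrapper and the weight choices are assumptions for illustration:

from functools import partial

import paddle
import paddle.nn.functional as F

bs, num_heads, seq_len, head_dim = 2, 12, 16, 64
s = paddle.randn([bs, num_heads, seq_len, head_dim])   # student Q (or K, V)
t = paddle.randn([bs, num_heads, seq_len, head_dim])   # teacher Q (or K, V)
attn_mask = paddle.zeros([bs, 1, 1, seq_len])

loss = calc_multi_relation_loss(partial(F.kl_div, reduction="sum"),
                                s, t, attn_mask,
                                num_relation_heads=24,
                                alpha=0.1, beta=0.1)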
Example No. 13
    def forward(self, input_ids, position_ids):
        if _global_parallel_strategy == "dp":
            auto.shard_tensor(input_ids,
                              dist_attr={
                                  "process_mesh": _global_process_mesh,
                                  "dims_mapping": [0, -1]
                              })
        elif _global_parallel_strategy == "dp_mp":
            auto.shard_tensor(input_ids,
                              dist_attr={
                                  "process_mesh": _global_process_mesh,
                                  "dims_mapping": [0, -1]
                              })

        input_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)

        if _global_parallel_strategy == "mp":
            auto.shard_tensor(self.word_embeddings.weight,
                              dist_attr={
                                  "process_mesh": _global_process_mesh,
                                  "dims_mapping": [0, -1]
                              })
        elif _global_parallel_strategy == "dp_mp":
            auto.shard_tensor(self.word_embeddings.weight,
                              dist_attr={
                                  "process_mesh": _global_process_mesh,
                                  "dims_mapping": [1, -1]
                              })

        embeddings = input_embeddings + position_embeddings
        embeddings = self.dropout1(embeddings)

        # Pre-norm
        target = self.norm1(embeddings)

        # The following is the attention part
        q = self.q_proj(target)
        q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim])
        q = tensor.transpose(x=q, perm=[0, 2, 1, 3])

        k = self.k_proj(target)
        v = self.v_proj(target)

        if _global_parallel_strategy == "mp":
            auto.shard_tensor(self.q_proj.weight,
                              dist_attr={
                                  "process_mesh": _global_process_mesh,
                                  "dims_mapping": [-1, 0]
                              })
            auto.shard_tensor(self.k_proj.weight,
                              dist_attr={
                                  "process_mesh": _global_process_mesh,
                                  "dims_mapping": [-1, 0]
                              })
            auto.shard_tensor(self.v_proj.weight,
                              dist_attr={
                                  "process_mesh": _global_process_mesh,
                                  "dims_mapping": [-1, 0]
                              })
        elif _global_parallel_strategy == "dp_mp":
            auto.shard_tensor(self.q_proj.weight,
                              dist_attr={
                                  "process_mesh": _global_process_mesh,
                                  "dims_mapping": [-1, 1]
                              })
            auto.shard_tensor(self.k_proj.weight,
                              dist_attr={
                                  "process_mesh": _global_process_mesh,
                                  "dims_mapping": [-1, 1]
                              })
            auto.shard_tensor(self.v_proj.weight,
                              dist_attr={
                                  "process_mesh": _global_process_mesh,
                                  "dims_mapping": [-1, 1]
                              })

        k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim])
        k = tensor.transpose(x=k, perm=[0, 2, 1, 3])
        v = tensor.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim])
        v = tensor.transpose(x=v, perm=[0, 2, 1, 3])

        # scaled dot-product attention
        product = layers.matmul(x=q,
                                y=k,
                                transpose_y=True,
                                alpha=self.head_dim**-0.5)

        if self.attn_mask is not None:
            product = product + self.attn_mask

        weights = F.softmax(product)

        if self.dropout_ratio:
            weights = F.dropout(weights,
                                self.dropout_ratio,
                                training=self.training,
                                mode="upscale_in_train")

        out = tensor.matmul(weights, v)

        # combine heads
        out = tensor.transpose(out, perm=[0, 2, 1, 3])
        out = tensor.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])

        # project to output
        out = self.out_proj(out)

        if _global_parallel_strategy == "mp":
            auto.shard_tensor(self.out_proj.weight,
                              dist_attr={
                                  "process_mesh": _global_process_mesh,
                                  "dims_mapping": [0, -1]
                              })
        elif _global_parallel_strategy == "dp_mp":
            auto.shard_tensor(self.out_proj.weight,
                              dist_attr={
                                  "process_mesh": _global_process_mesh,
                                  "dims_mapping": [1, -1]
                              })

        # Add residual
        residual = embeddings + self.dropout2(out)

        # Pre-norm
        out0 = self.norm2(residual)

        # The following is the MLP part
        out1 = self.linear0(out0)
        out2 = F.gelu(out1, approximate=True)
        out3 = self.linear1(out2)

        if _global_parallel_strategy == "mp":
            auto.shard_tensor(self.linear0.weight,
                              dist_attr={
                                  "process_mesh": _global_process_mesh,
                                  "dims_mapping": [-1, 0]
                              })
            auto.shard_tensor(self.linear1.weight,
                              dist_attr={
                                  "process_mesh": _global_process_mesh,
                                  "dims_mapping": [0, -1]
                              })
        elif _global_parallel_strategy == "dp_mp":
            auto.shard_tensor(self.linear0.weight,
                              dist_attr={
                                  "process_mesh": _global_process_mesh,
                                  "dims_mapping": [-1, 1]
                              })
            auto.shard_tensor(self.linear1.weight,
                              dist_attr={
                                  "process_mesh": _global_process_mesh,
                                  "dims_mapping": [1, -1]
                              })

        # Add residual
        final = residual + self.dropout3(out3)
        return final
Example No. 14
    def forward(self, input):
        if _global_parallel_strategy == "dp":
            auto.shard_tensor(input,
                              dist_attr={
                                  "process_mesh": _global_process_mesh,
                                  "dims_mapping": [0, -1, -1]
                              })
        elif _global_parallel_strategy == "dp_mp":
            auto.shard_tensor(input,
                              dist_attr={
                                  "process_mesh": _global_process_mesh,
                                  "dims_mapping": [0, -1, -1]
                              })

        q = self.q_proj(input)
        q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim])
        q = tensor.transpose(x=q, perm=[0, 2, 1, 3])

        k = self.k_proj(input)
        v = self.v_proj(input)

        if _global_parallel_strategy == "mp":
            auto.shard_tensor(self.q_proj.weight,
                              dist_attr={
                                  "process_mesh": _global_process_mesh,
                                  "dims_mapping": [-1, 0]
                              })
            auto.shard_tensor(self.k_proj.weight,
                              dist_attr={
                                  "process_mesh": _global_process_mesh,
                                  "dims_mapping": [-1, 0]
                              })
            auto.shard_tensor(self.v_proj.weight,
                              dist_attr={
                                  "process_mesh": _global_process_mesh,
                                  "dims_mapping": [-1, 0]
                              })
        elif _global_parallel_strategy == "dp_mp":
            auto.shard_tensor(self.q_proj.weight,
                              dist_attr={
                                  "process_mesh": _global_process_mesh,
                                  "dims_mapping": [-1, 1]
                              })
            auto.shard_tensor(self.k_proj.weight,
                              dist_attr={
                                  "process_mesh": _global_process_mesh,
                                  "dims_mapping": [-1, 1]
                              })
            auto.shard_tensor(self.v_proj.weight,
                              dist_attr={
                                  "process_mesh": _global_process_mesh,
                                  "dims_mapping": [-1, 1]
                              })

        k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim])
        k = tensor.transpose(x=k, perm=[0, 2, 1, 3])
        v = tensor.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim])
        v = tensor.transpose(x=v, perm=[0, 2, 1, 3])

        # scaled dot-product attention
        product = layers.matmul(x=q,
                                y=k,
                                transpose_y=True,
                                alpha=self.head_dim**-0.5)

        if self.attn_mask is not None:
            product = product + self.attn_mask

        weights = F.softmax(product)

        if self.dropout_ratio:
            weights = F.dropout(weights,
                                self.dropout_ratio,
                                training=self.training,
                                mode="upscale_in_train")

        out = tensor.matmul(weights, v)

        # combine heads
        out = tensor.transpose(out, perm=[0, 2, 1, 3])
        out = tensor.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])

        # project to output
        out = self.out_proj(out)
        if _global_parallel_strategy == "mp":
            auto.shard_tensor(self.out_proj.weight,
                              dist_attr={
                                  "process_mesh": _global_process_mesh,
                                  "dims_mapping": [0, -1]
                              })
        elif _global_parallel_strategy == "dp_mp":
            auto.shard_tensor(self.out_proj.weight,
                              dist_attr={
                                  "process_mesh": _global_process_mesh,
                                  "dims_mapping": [1, -1]
                              })

        return out
Example No. 15
 def forward(self,
             query,
             key,
             value,
             attn_mask=None,
             use_cache=False,
             cache=None):
     """
     Applies multi-head attention to map queries and a set of key-value pairs
     to outputs.
     """
     key = query if key is None else key
     value = query if value is None else value
      # compute q, k, v
     if use_cache is False:
         if self.fuse:
             q, k, v = self._fuse_prepare_qkv(query)
         else:
             q, k, v = self._prepare_qkv(query, key, value, use_cache,
                                         cache)
     else:
         q, k, v, cache = self._prepare_qkv(query, key, value, use_cache,
                                            cache)
     product = layers.matmul(x=q,
                             y=k,
                             transpose_y=True,
                             alpha=self.head_dim**-0.5)
     if attn_mask is not None:
         product = product + attn_mask
     weights = F.softmax(product)
     if self.dropout:
         weights = F.dropout(weights,
                             self.dropout,
                             training=self.training,
                             mode="upscale_in_train")
     out = tensor.matmul(weights, v)
     # combine heads
     out = tensor.transpose(out, perm=[0, 2, 1, 3])
     out = tensor.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])
     # project to output
     out = self.out_proj(out)
     if _global_parallel_strategy == "mp":
         auto.shard_tensor(self.out_proj.weight,
                           dist_attr={
                               "process_mesh": _global_process_mesh,
                               "dims_mapping": [0, -1]
                           })
     elif _global_parallel_strategy == "dp_mp":
         auto.shard_tensor(self.out_proj.weight,
                           dist_attr={
                               "process_mesh": _global_process_mesh,
                               "dims_mapping": [1, -1]
                           })
     elif _global_parallel_strategy == "mp_pp":
         auto.shard_tensor(self.out_proj.weight,
                           dist_attr={
                               "process_mesh":
                               MPPP_MESH_LIST[self.mesh_idx],
                               "dims_mapping": [0, -1]
                           })
     elif _global_parallel_strategy == "dp_mp_pp":
         auto.shard_tensor(self.out_proj.weight,
                           dist_attr={
                               "process_mesh":
                               DPMPPP_MESH_LIST[self.mesh_idx],
                               "dims_mapping": [1, -1]
                           })
     outs = [out]
     if self.need_weights:
         outs.append(weights)
     if use_cache:
         outs.append(cache)
     return out if len(outs) == 1 else tuple(outs)