Example #1
    def sparse_attention(self, query_layer, key_layer, value_layer,
                         attention_mask):
        # TODO: sparse attn dropout?
        # TODO: pad to block size
        # shape of q/k/v is [sq, b, np, hn] and needs to be transposed to [b, np, sq, hn]
        query_layer, key_layer, value_layer = map(
            lambda t: t.permute(1, 2, 0, 3).contiguous(),
            (query_layer, key_layer, value_layer),
        )
        # output shape [b, np(heads), sq, hn]
        attn_mask = attention_mask.to(query_layer.dtype) * -10000
        if exists(self.rpe):
            # after the permute, the sequence length sits in dim 2, so use
            # size(2) rather than size(0) (which is now the batch dimension)
            rpe = self.rpe(query_layer.size(2), key_layer.size(2))
        else:
            rpe = None
        return self.sparse_attn(query_layer,
                                key_layer,
                                value_layer,
                                attn_mask=attn_mask,
                                rpe=rpe)
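
For reference, a standalone sketch of the two conventions used here: the [sq, b, np, hn] -> [b, np, sq, hn] permute and the additive mask. The sizes are toy values, and the assumption that nonzero mask entries mark blocked positions is mine, not taken from the repo:

import torch

# toy sizes: sq (sequence), b (batch), np (heads), hn (head dim)
sq, b, np_, hn = 8, 2, 4, 16
q = torch.randn(sq, b, np_, hn)

# [sq, b, np, hn] --> [b, np, sq, hn], as in the permute above
q = q.permute(1, 2, 0, 3).contiguous()
assert q.shape == (b, np_, sq, hn)

# boolean mask (True = blocked) turned into a large negative additive bias,
# mirroring `attention_mask.to(query_layer.dtype) * -10000`
mask = torch.triu(torch.ones(sq, sq, dtype=torch.bool), diagonal=1)
attn_mask = mask.to(q.dtype) * -10000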
Example #2
    def forward(self, hidden_states, attention_mask, rotary_pos_emb=None, layer_past=None,
                get_key_value=False):
        # hidden_states: [sq, b, h]

        # =====================
        # Query, Key, and Value
        # =====================

        # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
        mixed_x_layer, _ = self.query_key_value(hidden_states)

        checkpoint_version = get_checkpoint_version()
        if checkpoint_version is not None:
            if checkpoint_version == 0:
                # [s, b, (3 * np * hn)] --> [s, b, (np * 3 * hn)]
                mixed_x_layer = self._transpose_last_dim(mixed_x_layer, 3, True)
            elif checkpoint_version == 1.0:
                # [s, b, (np * hn * 3)] --> [s, b, (np * 3 * hn)]
                mixed_x_layer = self._transpose_last_dim(mixed_x_layer, 3, False)

        # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
        new_tensor_shape = mixed_x_layer.size()[:-1] + \
                           (self.num_attention_heads_per_partition,
                            3 * self.hidden_size_per_attention_head)
        mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

        # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
        (query_layer,
         key_layer,
         value_layer) = mpu.split_tensor_along_last_dim(mixed_x_layer, 3)

        if exists(rotary_pos_emb):
            query_layer, key_layer = apply_rotary_pos_emb(query_layer, key_layer, rotary_pos_emb)

        # ==================================
        # Adjust key and value for inference
        # ==================================

        if layer_past is not None:
            past_key, past_value = layer_past
            key_layer = torch.cat((past_key.type_as(key_layer),
                                   key_layer), dim=0)
            value_layer = torch.cat((past_value.type_as(value_layer),
                                     value_layer), dim=0)
        if get_key_value:
            present = (key_layer, value_layer)

        if not self.sparse:
            # ===================================
            # Raw attention scores. [b, np, s, s]
            # ===================================

            # [b, np, sq, sk]
            output_size = (query_layer.size(1),
                           query_layer.size(2),
                           query_layer.size(0),
                           key_layer.size(0))

            # [sq, b, np, hn] -> [sq, b * np, hn]
            query_layer = query_layer.view(output_size[2],
                                           output_size[0] * output_size[1], -1)
            key_layer = key_layer.view(output_size[3],
                                       output_size[0] * output_size[1], -1)

            # preallocating result tensor: [b * np, sq, sk]
            matmul_result = torch.empty(
                output_size[0] * output_size[1],
                output_size[2],
                output_size[3],
                dtype=query_layer.dtype,
                device=torch.cuda.current_device())

            # Raw attention scores. [b * np, sq, sk]
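            # with beta=0.0, the uninitialized buffer is never read; baddbmm only
            # fixes the output's dtype/device and applies the 1/norm_factor scale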
            matmul_result = torch.baddbmm(matmul_result,
                                          query_layer.transpose(0, 1),  # [b * np, sq, hn]
                                          key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
                                          beta=0.0, alpha=(1.0 / self.norm_factor))

            # change view to [b, np, sq, sk]
            attention_scores = matmul_result.view(*output_size)

            # ==================================================
            # Update attention mask for inference. [b, np, sq, sk]
            # ==================================================

            if get_key_value:
                with torch.no_grad():
                    if layer_past is not None:
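                        # decoding a single new token: keep only the last row
                        # of the causal mask for the one new query position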
                        attention_mask = attention_mask[
                                         ...,
                                         attention_scores.size(3) - 1,
                                         :attention_scores.size(3)].unsqueeze(2)
                    else:
                        attention_mask = attention_mask[
                                         ...,
                                         :attention_scores.size(3),
                                         :attention_scores.size(3)]

            # ===========================
            # Attention probs and dropout
            # ===========================

            if exists(self.rpe):
                rpe = self.rpe(query_layer.size(0), key_layer.size(0))
                attention_scores += rpe  # [1, np, sq, sk]

            # attention scores and attention mask [b, np, sq, sk]
            attention_probs = self.scale_mask_softmax(attention_scores,
                                                      attention_mask)

            # This is actually dropping out entire tokens to attend to, which might
            # seem a bit unusual, but is taken from the original Transformer paper.
            with mpu.get_cuda_rng_tracker().fork():
                attention_probs = self.attention_dropout(attention_probs)

            # =========================
            # Context layer. [sq, b, hp]
            # =========================

            # value_layer -> context layer.
            # [sk, b, np, hn] --> [b, np, sq, hn]

            # context layer shape: [b, np, sq, hn]
            output_size = (value_layer.size(1),
                           value_layer.size(2),
                           query_layer.size(0),
                           value_layer.size(3))

            # change view [sk, b * np, hn]
            value_layer = value_layer.view(value_layer.size(0),
                                           output_size[0] * output_size[1], -1)

            # change view [b * np, sq, sk]
            attention_probs = attention_probs.view(output_size[0] * output_size[1],
                                                   output_size[2], -1)

            # matmul: [b * np, sq, hn]
            context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))

            # change view [b, np, sq, hn]
            context_layer = context_layer.view(*output_size)
        else:
            # shape of q/k/v is [sq, b, np, hn] and needs to be transposed to [b, np, sq, hn]
            query_layer, key_layer, value_layer = map(lambda t: t.permute(1, 2, 0, 3).contiguous(),
                                                      (query_layer, key_layer,
                                                       value_layer))
            # output shape [b, np(heads), sq, hn]
            attn_mask = attention_mask.to(query_layer.dtype) * -10000
            if exists(self.rpe):
                # after the permute, the sequence length sits in dim 2, so use
                # size(2) rather than size(0) (which is now the batch dimension)
                rpe = self.rpe(query_layer.size(2), key_layer.size(2))
            else:
                rpe = None
            context_layer = self.sparse_attn(query_layer, key_layer, value_layer, attn_mask=attn_mask, rpe=rpe)

        # [b, np, sq, hn] --> [sq, b, np, hn]
        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

        # [sq, b, np, hn] --> [sq, b, hp]
        new_context_layer_shape = context_layer.size()[:-2] + \
                                  (self.hidden_size_per_partition,)
        context_layer = context_layer.view(*new_context_layer_shape)

        # =================
        # Output. [sq, b, h]
        # =================

        output, bias = self.dense(context_layer)

        if get_key_value:
            output = [output, present]

        return output, bias
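
The dense branch's score computation can be checked in isolation. A minimal sketch, assuming self.norm_factor is sqrt(hn) (a common choice; the repo may compute it differently) and using toy CPU tensors:

import math
import torch

sq, sk, b, np_, hn = 8, 8, 2, 4, 16
norm_factor = math.sqrt(hn)

q = torch.randn(sq, b * np_, hn)  # [sq, b*np, hn]
k = torch.randn(sk, b * np_, hn)  # [sk, b*np, hn]

# preallocate [b*np, sq, sk]; with beta=0.0 its contents are never read
buf = torch.empty(b * np_, sq, sk, dtype=q.dtype)
scores = torch.baddbmm(buf,
                       q.transpose(0, 1),                  # [b*np, sq, hn]
                       k.transpose(0, 1).transpose(1, 2),  # [b*np, hn, sk]
                       beta=0.0, alpha=1.0 / norm_factor)
assert scores.shape == (b * np_, sq, sk)
scores = scores.view(b, np_, sq, sk)  # the [b, np, sq, sk] view used downstream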
Example #3
    def forward(self, hidden_states, attention_mask, layer_past=None):

        # hidden_states: [sq, b, h]

        # =====================
        # Query, Key, and Value
        # =====================

        # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
        mixed_x_layer, _ = self.query_key_value(hidden_states)

        # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
        new_tensor_shape = mixed_x_layer.size()[:-1] + (
            self.num_attention_heads_per_partition,
            3 * self.hidden_size_per_attention_head,
        )
        mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

        # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
        (query_layer, key_layer,
         value_layer) = mpu.split_tensor_along_last_dim(mixed_x_layer, 3)

        if exists(self.rotary_emb):
            if exists(self.rotary_ndims):
                # partial rotary
                query_rot, query_pass = (
                    query_layer[..., :self.rotary_ndims],
                    query_layer[..., self.rotary_ndims:],
                )
                key_rot, key_pass = (
                    key_layer[..., :self.rotary_ndims],
                    key_layer[..., self.rotary_ndims:],
                )
            else:
                # full rotary
                query_rot, key_rot = query_layer, key_layer
            apply_rotary_fn = (apply_rotary_pos_emb_torch
                               if self.bf16 else apply_rotary_pos_emb)

            seq_len = key_layer.shape[0]
            offset = 0
            if exists(layer_past) and layer_past.numel() > 0:
                offset = layer_past[0].shape[0]
                seq_len += offset
            cos, sin = self.rotary_emb(value_layer, seq_len=seq_len)
            query_layer, key_layer = apply_rotary_fn(query_rot,
                                                     key_rot,
                                                     cos,
                                                     sin,
                                                     offset=offset)

            if exists(self.rotary_ndims):
                query_layer = torch.cat((query_layer, query_pass), dim=-1)
                key_layer = torch.cat((key_layer, key_pass), dim=-1)

        # ==================================
        # Cache key and value for inference
        # ==================================

        if exists(layer_past) and layer_past.numel() > 0:
            past_key, past_value = layer_past
            key_layer = torch.cat((past_key.type_as(key_layer), key_layer),
                                  dim=0)
            value_layer = torch.cat(
                (past_value.type_as(value_layer), value_layer), dim=0)

        if self.use_cache:
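            # stack k and v into a single [2, sk, b, np, hn] tensor for the cache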
            present = torch.stack((key_layer, value_layer))

        if not self.sparse:
            context_layer = self.attention(query_layer, key_layer, value_layer,
                                           layer_past, attention_mask)
        else:
            context_layer = self.sparse_attention(query_layer, key_layer,
                                                  value_layer, attention_mask)

        # [b, np, sq, hn] --> [sq, b, np, hn]
        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

        # [sq, b, np, hn] --> [sq, b, hp]
        new_context_layer_shape = context_layer.size()[:-2] + (
            self.hidden_size_per_partition, )
        context_layer = context_layer.view(*new_context_layer_shape)

        # =================
        # Output. [sq, b, h]
        # =================

        output, bias = self.dense(context_layer)

        if self.use_cache:
            output = [output, present]

        return output, bias
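
The partial-rotary branch rotates only the first rotary_ndims channels of each head and passes the rest through unchanged. A self-contained sketch using a common RoPE formulation; rotate_half and the cos/sin construction here are illustrative stand-ins, not the repo's rotary_emb / apply_rotary_pos_emb:

import torch

sq, b, np_, hn, rotary_ndims = 8, 2, 4, 16, 8
q = torch.randn(sq, b, np_, hn)

# split the head dim: first rotary_ndims channels get rotary, the rest pass through
q_rot, q_pass = q[..., :rotary_ndims], q[..., rotary_ndims:]

def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

# standard RoPE frequency table, broadcast to [sq, 1, 1, rotary_ndims]
inv_freq = 1.0 / (10000 ** (torch.arange(0, rotary_ndims, 2).float() / rotary_ndims))
freqs = torch.einsum("i,j->ij", torch.arange(sq).float(), inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)[:, None, None, :]
cos, sin = emb.cos(), emb.sin()

q_rot = q_rot * cos + rotate_half(q_rot) * sin

# recombine, mirroring torch.cat((query_layer, query_pass), dim=-1)
q = torch.cat((q_rot, q_pass), dim=-1)
assert q.shape == (sq, b, np_, hn)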
Example #4
    def attention(self, query_layer, key_layer, value_layer, layer_past,
                  attention_mask):
        # ===================================
        # Raw attention scores. [b, np, s, s]
        # ===================================

        # [b, np, sq, sk]
        output_size = (
            query_layer.size(1),
            query_layer.size(2),
            query_layer.size(0),
            key_layer.size(0),
        )

        # [sq, b, np, hn] -> [sq, b * np, hn]
        query_layer = query_layer.view(output_size[2],
                                       output_size[0] * output_size[1], -1)
        key_layer = key_layer.view(output_size[3],
                                   output_size[0] * output_size[1], -1)

        # preallocating result tensor: [b * np, sq, sk]
        matmul_result = torch.empty(
            output_size[0] * output_size[1],
            output_size[2],
            output_size[3],
            dtype=query_layer.dtype,
            device=torch.cuda.current_device(),
        )

        # Raw attention scores. [b * np, sq, sk]
        matmul_result = torch.baddbmm(
            matmul_result,
            query_layer.transpose(0, 1),  # [b * np, sq, hn]
            key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
            beta=0.0,
            alpha=(1.0 / self.norm_factor),
        )

        # change view to [b, np, sq, sk]
        attention_scores = matmul_result.view(*output_size)

        # ==================================================
        # Update attention mask for inference. [b, np, sq, sk]
        # ==================================================

        if self.use_cache:
            with torch.no_grad():
                attention_mask = attention_mask[
                    ..., :attention_scores.size(3), :attention_scores.size(3)]

        # ===========================
        # Attention probs and dropout
        # ===========================

        if exists(self.rpe):
            rpe = self.rpe(query_layer.size(0), key_layer.size(0))
            attention_scores += rpe  # [1, np, sq, sk]

        if self.pos_emb == "alibi":
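            # ALiBi adds head-specific linear position biases to the raw scores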
            attention_scores = self.alibi_embed(attention_scores)

        # attention scores and attention mask [b, np, sq, sk]
        attention_probs = self.scale_mask_softmax(attention_scores,
                                                  attention_mask)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        with mpu.get_cuda_rng_tracker().fork():
            attention_probs = self.attention_dropout(attention_probs)

        # =========================
        # Context layer. [sq, b, hp]
        # =========================

        # value_layer -> context layer.
        # [sk, b, np, hn] --> [b, np, sq, hn]

        # context layer shape: [b, np, sq, hn]
        output_size = (
            value_layer.size(1),
            value_layer.size(2),
            query_layer.size(0),
            value_layer.size(3),
        )

        # change view [sk, b * np, hn]
        value_layer = value_layer.view(value_layer.size(0),
                                       output_size[0] * output_size[1], -1)

        # change view [b * np, sq, sk]
        attention_probs = attention_probs.view(output_size[0] * output_size[1],
                                               output_size[2], -1)

        # matmul: [b * np, sq, hn]
        context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))

        # change view [b, np, sq, hn]
        context_layer = context_layer.view(*output_size)
        return context_layer
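
The context computation reduces to a single batched matmul over the flattened batch-head dimension. A shape-check sketch with toy sizes:

import torch

sq, sk, b, np_, hn = 8, 8, 2, 4, 16
probs = torch.randn(b * np_, sq, sk).softmax(dim=-1)  # [b*np, sq, sk]
v = torch.randn(sk, b * np_, hn)                      # [sk, b*np, hn]

# [b*np, sq, sk] @ [b*np, sk, hn] -> [b*np, sq, hn]
context = torch.bmm(probs, v.transpose(0, 1))
context = context.view(b, np_, sq, hn)  # the [b, np, sq, hn] view returned above
assert context.shape == (b, np_, sq, hn)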