def sparse_attention(self, query_layer, key_layer, value_layer, attention_mask):
    # TODO: sparse attn dropout?
    # TODO: pad to block size
    # shape of q/k/v is [sq, b, np, hn] and needs to be transposed to [b, np, sq, hn]
    query_layer, key_layer, value_layer = map(
        lambda t: t.permute(1, 2, 0, 3).contiguous(),
        (query_layer, key_layer, value_layer),
    )
    # output shape [b, np(heads), sq, hn]
    attn_mask = attention_mask.to(query_layer.dtype) * -10000
    if exists(self.rpe):
        rpe = self.rpe(query_layer.size(0), key_layer.size(0))
    else:
        rpe = None
    return self.sparse_attn(query_layer, key_layer, value_layer, attn_mask=attn_mask, rpe=rpe)
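# --- Standalone shape sketch (not part of the module above; toy dimensions chosen
# for illustration). It mirrors the [sq, b, np, hn] -> [b, np, sq, hn] permute and the
# additive-mask convention (`attention_mask.to(dtype) * -10000`) used by
# sparse_attention, so the layout handed to the sparse kernel is easy to verify.
import torch


def _sparse_attention_shape_sketch():
    sq, b, np_, hn = 8, 2, 4, 16                       # toy seq-len, batch, heads, head-dim
    q = torch.randn(sq, b, np_, hn)                    # [sq, b, np, hn], as produced upstream
    q_t = q.permute(1, 2, 0, 3).contiguous()           # [b, np, sq, hn], layout the kernel expects
    assert q_t.shape == (b, np_, sq, hn)

    # A 0/1 mask (1 = masked out) becomes a large negative additive bias.
    mask = torch.triu(torch.ones(1, 1, sq, sq), diagonal=1)
    additive = mask.to(q_t.dtype) * -10000
    assert additive.max() == 0 and additive.min() == -10000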
def forward(self, hidden_states, attention_mask, rotary_pos_emb=None, layer_past=None, get_key_value=False):
    # hidden_states: [sq, b, h]

    # =====================
    # Query, Key, and Value
    # =====================

    # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
    mixed_x_layer, _ = self.query_key_value(hidden_states)

    checkpoint_version = get_checkpoint_version()
    if checkpoint_version is not None:
        if checkpoint_version == 0:
            # [s, b, (3 * np * hn)] --> [s, b, (np * 3 * hn)]
            mixed_x_layer = self._transpose_last_dim(mixed_x_layer, 3, True)
        elif checkpoint_version == 1.0:
            # [s, b, (np * hn * 3)] --> [s, b, (np * 3 * hn)]
            mixed_x_layer = self._transpose_last_dim(mixed_x_layer, 3, False)

    # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
    new_tensor_shape = mixed_x_layer.size()[:-1] + \
        (self.num_attention_heads_per_partition,
         3 * self.hidden_size_per_attention_head)
    mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

    # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
    (query_layer, key_layer, value_layer) = mpu.split_tensor_along_last_dim(mixed_x_layer, 3)

    if exists(rotary_pos_emb):
        query_layer, key_layer = apply_rotary_pos_emb(query_layer, key_layer, rotary_pos_emb)

    # ==================================
    # Adjust key and value for inference
    # ==================================

    if layer_past is not None:
        past_key, past_value = layer_past
        key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=0)
        value_layer = torch.cat((past_value.type_as(value_layer), value_layer), dim=0)
    if get_key_value:
        present = (key_layer, value_layer)

    if not self.sparse:
        # ===================================
        # Raw attention scores. [b, np, s, s]
        # ===================================

        # [b, np, sq, sk]
        output_size = (query_layer.size(1),
                       query_layer.size(2),
                       query_layer.size(0),
                       key_layer.size(0))

        # [sq, b, np, hn] -> [sq, b * np, hn]
        query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1)
        key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1)

        # preallocating result tensor: [b * np, sq, sk]
        matmul_result = torch.empty(
            output_size[0] * output_size[1],
            output_size[2],
            output_size[3],
            dtype=query_layer.dtype,
            device=torch.cuda.current_device())

        # Raw attention scores. [b * np, sq, sk]
        matmul_result = torch.baddbmm(
            matmul_result,
            query_layer.transpose(0, 1),  # [b * np, sq, hn]
            key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
            beta=0.0,
            alpha=(1.0 / self.norm_factor))

        # change view to [b, np, sq, sk]
        attention_scores = matmul_result.view(*output_size)

        # ====================================================
        # Update attention mask for inference. [b, np, sq, sk]
        # ====================================================

        if get_key_value:
            with torch.no_grad():
                if layer_past is not None:
                    attention_mask = attention_mask[
                        ...,
                        attention_scores.size(3) - 1,
                        :attention_scores.size(3)].unsqueeze(2)
                else:
                    attention_mask = attention_mask[
                        ...,
                        :attention_scores.size(3),
                        :attention_scores.size(3)]

        # ===========================
        # Attention probs and dropout
        # ===========================

        if exists(self.rpe):
            rpe = self.rpe(query_layer.size(0), key_layer.size(0))
            attention_scores += rpe  # [1, np, sq, sk]

        # attention scores and attention mask [b, np, sq, sk]
        attention_probs = self.scale_mask_softmax(attention_scores, attention_mask)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        with mpu.get_cuda_rng_tracker().fork():
            attention_probs = self.attention_dropout(attention_probs)

        # =========================
        # Context layer. [sq, b, hp]
        # =========================

        # value_layer -> context layer.
        # [sk, b, np, hn] --> [b, np, sq, hn]

        # context layer shape: [b, np, sq, hn]
        output_size = (value_layer.size(1),
                       value_layer.size(2),
                       query_layer.size(0),
                       value_layer.size(3))

        # change view [sk, b * np, hn]
        value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1)

        # change view [b * np, sq, sk]
        attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)

        # matmul: [b * np, sq, hn]
        context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))

        # change view [b, np, sq, hn]
        context_layer = context_layer.view(*output_size)
    else:
        # shape of q/k/v is [sq, b, np, hn] and needs to be transposed to [b, np, sq, hn]
        query_layer, key_layer, value_layer = map(
            lambda t: t.permute(1, 2, 0, 3).contiguous(),
            (query_layer, key_layer, value_layer))
        # output shape [b, np(heads), sq, hn]
        attn_mask = attention_mask.to(query_layer.dtype) * -10000
        if exists(self.rpe):
            rpe = self.rpe(query_layer.size(0), key_layer.size(0))
        else:
            rpe = None
        context_layer = self.sparse_attn(query_layer, key_layer, value_layer, attn_mask=attn_mask, rpe=rpe)

    # [b, np, sq, hn] --> [sq, b, np, hn]
    context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

    # [sq, b, np, hn] --> [sq, b, hp]
    new_context_layer_shape = context_layer.size()[:-2] + \
        (self.hidden_size_per_partition,)
    context_layer = context_layer.view(*new_context_layer_shape)

    # =================
    # Output. [sq, b, h]
    # =================

    output, bias = self.dense(context_layer)

    if get_key_value:
        output = [output, present]

    return output, bias
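# --- Standalone shape sketch (not part of the module above; toy dimensions chosen
# for illustration). It mirrors the final reshape in forward(), going from the
# per-head context [b, np, sq, hn] back to the sequence-first hidden layout
# [sq, b, hp] with hp = np * hn.
import torch


def _context_reshape_sketch():
    b, np_, sq, hn = 2, 4, 8, 16
    context = torch.randn(b, np_, sq, hn)                 # [b, np, sq, hn]
    context = context.permute(2, 0, 1, 3).contiguous()    # [sq, b, np, hn]
    context = context.view(sq, b, np_ * hn)               # [sq, b, hp]
    assert context.shape == (sq, b, np_ * hn)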
def forward(self, hidden_states, attention_mask, layer_past=None):
    # hidden_states: [sq, b, h]

    # =====================
    # Query, Key, and Value
    # =====================

    # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
    mixed_x_layer, _ = self.query_key_value(hidden_states)

    # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
    new_tensor_shape = mixed_x_layer.size()[:-1] + (
        self.num_attention_heads_per_partition,
        3 * self.hidden_size_per_attention_head,
    )
    mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

    # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
    (query_layer, key_layer, value_layer) = mpu.split_tensor_along_last_dim(mixed_x_layer, 3)

    if exists(self.rotary_emb):
        if exists(self.rotary_ndims):
            # partial rotary
            query_rot, query_pass = (
                query_layer[..., :self.rotary_ndims],
                query_layer[..., self.rotary_ndims:],
            )
            key_rot, key_pass = (
                key_layer[..., :self.rotary_ndims],
                key_layer[..., self.rotary_ndims:],
            )
        else:
            # full rotary
            query_rot, key_rot = query_layer, key_layer

        apply_rotary_fn = (apply_rotary_pos_emb_torch if self.bf16 else apply_rotary_pos_emb)

        seq_len = key_layer.shape[0]
        offset = 0
        if exists(layer_past) and layer_past.numel() > 0:
            offset = layer_past[0].shape[0]
            seq_len += offset
        cos, sin = self.rotary_emb(value_layer, seq_len=seq_len)
        query_layer, key_layer = apply_rotary_fn(query_rot, key_rot, cos, sin, offset=offset)

        if exists(self.rotary_ndims):
            query_layer = torch.cat((query_layer, query_pass), dim=-1)
            key_layer = torch.cat((key_layer, key_pass), dim=-1)

    # ==================================
    # Cache key and value for inference
    # ==================================

    if exists(layer_past) and layer_past.numel() > 0:
        past_key, past_value = layer_past
        key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=0)
        value_layer = torch.cat((past_value.type_as(value_layer), value_layer), dim=0)

    if self.use_cache:
        present = torch.stack((key_layer, value_layer))

    if not self.sparse:
        context_layer = self.attention(query_layer, key_layer, value_layer, layer_past, attention_mask)
    else:
        context_layer = self.sparse_attention(query_layer, key_layer, value_layer, attention_mask)

    # [b, np, sq, hn] --> [sq, b, np, hn]
    context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

    # [sq, b, np, hn] --> [sq, b, hp]
    new_context_layer_shape = context_layer.size()[:-2] + (
        self.hidden_size_per_partition,
    )
    context_layer = context_layer.view(*new_context_layer_shape)

    # =================
    # Output. [sq, b, h]
    # =================

    output, bias = self.dense(context_layer)

    if self.use_cache:
        output = [output, present]

    return output, bias
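# --- Standalone shape sketch (not part of the module above; toy dimensions chosen
# for illustration). It mirrors the key/value cache handling in forward(): past keys
# and values are concatenated along the sequence dimension (dim 0) and re-stacked
# as `present` to be fed back in as `layer_past` on the next decoding step.
import torch


def _kv_cache_sketch():
    sq_past, sq_new, b, np_, hn = 5, 1, 2, 4, 16
    past_key = torch.randn(sq_past, b, np_, hn)
    past_value = torch.randn(sq_past, b, np_, hn)
    key = torch.randn(sq_new, b, np_, hn)
    value = torch.randn(sq_new, b, np_, hn)

    key = torch.cat((past_key.type_as(key), key), dim=0)        # [sq_past + sq_new, b, np, hn]
    value = torch.cat((past_value.type_as(value), value), dim=0)
    present = torch.stack((key, value))                         # [2, sq_past + sq_new, b, np, hn]
    assert present.shape == (2, sq_past + sq_new, b, np_, hn)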
def attention(self, query_layer, key_layer, value_layer, layer_past, attention_mask):
    # ===================================
    # Raw attention scores. [b, np, s, s]
    # ===================================

    # [b, np, sq, sk]
    output_size = (
        query_layer.size(1),
        query_layer.size(2),
        query_layer.size(0),
        key_layer.size(0),
    )

    # [sq, b, np, hn] -> [sq, b * np, hn]
    query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1)
    key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1)

    # preallocating result tensor: [b * np, sq, sk]
    matmul_result = torch.empty(
        output_size[0] * output_size[1],
        output_size[2],
        output_size[3],
        dtype=query_layer.dtype,
        device=torch.cuda.current_device(),
    )

    # Raw attention scores. [b * np, sq, sk]
    matmul_result = torch.baddbmm(
        matmul_result,
        query_layer.transpose(0, 1),  # [b * np, sq, hn]
        key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
        beta=0.0,
        alpha=(1.0 / self.norm_factor),
    )

    # change view to [b, np, sq, sk]
    attention_scores = matmul_result.view(*output_size)

    # ====================================================
    # Update attention mask for inference. [b, np, sq, sk]
    # ====================================================

    if self.use_cache:
        with torch.no_grad():
            attention_mask = attention_mask[
                ..., :attention_scores.size(3), :attention_scores.size(3)]

    # ===========================
    # Attention probs and dropout
    # ===========================

    if exists(self.rpe):
        rpe = self.rpe(query_layer.size(0), key_layer.size(0))
        attention_scores += rpe  # [1, np, sq, sk]

    if self.pos_emb == "alibi":
        attention_scores = self.alibi_embed(attention_scores)

    # attention scores and attention mask [b, np, sq, sk]
    attention_probs = self.scale_mask_softmax(attention_scores, attention_mask)

    # This is actually dropping out entire tokens to attend to, which might
    # seem a bit unusual, but is taken from the original Transformer paper.
    with mpu.get_cuda_rng_tracker().fork():
        attention_probs = self.attention_dropout(attention_probs)

    # =========================
    # Context layer. [sq, b, hp]
    # =========================

    # value_layer -> context layer.
    # [sk, b, np, hn] --> [b, np, sq, hn]

    # context layer shape: [b, np, sq, hn]
    output_size = (
        value_layer.size(1),
        value_layer.size(2),
        query_layer.size(0),
        value_layer.size(3),
    )

    # change view [sk, b * np, hn]
    value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1)

    # change view [b * np, sq, sk]
    attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)

    # matmul: [b * np, sq, hn]
    context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))

    # change view [b, np, sq, hn]
    context_layer = context_layer.view(*output_size)

    return context_layer
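# --- Standalone correctness sketch (not part of the module above; toy dimensions
# chosen for illustration, CPU only). It checks that the view/transpose/baddbmm
# sequence used in attention() produces the same scaled scores as a plain batched
# matmul of [sq, b, np, hn] queries against [sk, b, np, hn] keys.
import torch


def _baddbmm_score_sketch():
    sq, sk, b, np_, hn = 8, 8, 2, 4, 16
    norm_factor = hn ** 0.5
    q = torch.randn(sq, b, np_, hn)                             # [sq, b, np, hn]
    k = torch.randn(sk, b, np_, hn)                             # [sk, b, np, hn]

    q_flat = q.view(sq, b * np_, hn)                            # [sq, b * np, hn]
    k_flat = k.view(sk, b * np_, hn)                            # [sk, b * np, hn]
    out = torch.empty(b * np_, sq, sk, dtype=q.dtype)           # preallocated result buffer
    scores = torch.baddbmm(
        out,
        q_flat.transpose(0, 1),                                 # [b * np, sq, hn]
        k_flat.transpose(0, 1).transpose(1, 2),                 # [b * np, hn, sk]
        beta=0.0,                                               # beta=0 ignores the buffer contents
        alpha=1.0 / norm_factor,
    ).view(b, np_, sq, sk)

    ref = torch.einsum("sbnh,tbnh->bnst", q, k) / norm_factor   # [b, np, sq, sk]
    assert torch.allclose(scores, ref, atol=1e-5)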