def __init__(self, attention_mask_func, init_method,
             output_layer_init_method, layer_number):
    super(ParallelSelfAttention, self).__init__()
    args = get_args()
    self.fp16 = args.fp16

    self.attention_mask_func = attention_mask_func
    self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
    self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
    if self.apply_query_key_layer_scaling:
        self.attention_softmax_in_fp32 = True
    self.layer_number = max(1, layer_number)

    # Per attention head and per partition values.
    world_size = mpu.get_model_parallel_world_size()
    self.hidden_size_per_partition = mpu.divide(args.hidden_size, world_size)
    self.hidden_size_per_attention_head = mpu.divide(
        args.hidden_size, args.num_attention_heads)
    self.num_attention_heads_per_partition = mpu.divide(
        args.num_attention_heads, world_size)

    # Strided linear layer.
    self.query_key_value = mpu.ColumnParallelLinear(
        args.hidden_size,
        3 * args.hidden_size,
        gather_output=False,
        init_method=init_method)

    coeff = None
    self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
    if self.apply_query_key_layer_scaling:
        coeff = self.layer_number
        self.norm_factor *= coeff

    self.scale_mask_softmax = FusedScaleMaskSoftmax(
        self.fp16,
        args.scaled_upper_triang_masked_softmax_fusion,
        args.scaled_masked_softmax_fusion,
        self.attention_mask_func,
        self.attention_softmax_in_fp32,
        coeff)

    # Dropout. Note that for a single iteration, this layer will generate
    # different outputs on different number of parallel partitions but
    # on average it should not be partition dependent.
    self.attention_dropout = torch.nn.Dropout(args.attention_dropout)

    # Output.
    self.dense = mpu.RowParallelLinear(
        args.hidden_size,
        args.hidden_size,
        input_is_parallel=True,
        init_method=output_layer_init_method,
        skip_bias_add=True)

    if deepspeed.checkpointing.is_configured():
        global get_cuda_rng_tracker, checkpoint
        get_cuda_rng_tracker = deepspeed.checkpointing.get_cuda_rng_tracker
        checkpoint = deepspeed.checkpointing.checkpoint
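# Why query-key layer scaling helps fp16: above, `norm_factor` picks up an
# extra factor of `layer_number` (coeff), and the same coeff is handed to
# FusedScaleMaskSoftmax, which multiplies it back in before the (fp32)
# softmax. The pre-softmax logits are mathematically unchanged, but the fp16
# matmul output is coeff times smaller and therefore less likely to overflow.
# A minimal self-contained sketch (plain torch, illustrative names only):
import math

import torch


def _qk_layer_scaling_sketch(layer_number=24, head_dim=64, seq=8):
    q = torch.randn(seq, head_dim)
    k = torch.randn(seq, head_dim)
    coeff = max(1, layer_number)
    # Baseline: scores scaled only by 1/sqrt(head_dim).
    baseline = torch.softmax(q @ k.T / math.sqrt(head_dim), dim=-1)
    # Layer-scaled: divide by coeff * sqrt(head_dim) in the matmul, then
    # rescale by coeff inside the softmax. Same result, smaller intermediates.
    norm_factor = coeff * math.sqrt(head_dim)
    scaled = torch.softmax((q @ k.T / norm_factor) * coeff, dim=-1)
    assert torch.allclose(baseline, scaled, atol=1e-6)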
def __init__(self, neox_args, d_attn, d_ff, mask_fn):
    super().__init__()
    self.proj_qkv = nn.Linear(d_ff * 2, 3 * d_attn)
    self.scale = d_attn**-0.5
    self.proj_ffn = nn.Linear(d_attn, d_ff)
    self.softmax = FusedScaleMaskSoftmax(
        input_in_fp16=neox_args.precision == "fp16",
        input_in_bf16=neox_args.precision == "bfloat16",
        fusion_type=get_fusion_type(neox_args),
        mask_func=mask_fn,
        softmax_in_fp32=neox_args.attention_softmax_in_fp32,
        scale=None,
    )
def __init__(self, neox_args, d_attn, d_ff, mask_fn):
    super().__init__()
    self.proj_qkv = nn.Linear(d_ff * 2, 3 * d_attn)
    self.scale = d_attn**-0.5
    self.proj_ffn = nn.Linear(d_attn, d_ff)
    self.softmax = FusedScaleMaskSoftmax(
        input_in_fp16=neox_args.precision == "fp16",
        upper_triang_mask_fusion=neox_args.scaled_upper_triang_masked_softmax_fusion,
        general_mask_fusion=neox_args.scaled_masked_softmax_fusion,
        mask_func=mask_fn,
        softmax_in_fp32=neox_args.attention_softmax_in_fp32,
        scale=None)
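# The two constructors above are the same module before and after the
# FusedScaleMaskSoftmax signature changed: the pair of booleans
# (upper_triang_mask_fusion, general_mask_fusion) collapsed into a single
# fusion_type argument. A hedged sketch of a get_fusion_type helper doing
# that mapping; the enum members match the tests further down, but this
# body is an assumption rather than the verbatim library code:
from megatron.model.fused_softmax import SoftmaxFusionTypes


def get_fusion_type(neox_args):
    if neox_args.scaled_upper_triang_masked_softmax_fusion:
        return SoftmaxFusionTypes.upper_triang
    if neox_args.scaled_masked_softmax_fusion:
        return SoftmaxFusionTypes.general
    return SoftmaxFusionTypes.none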
def __init__(self, attention_mask_func, init_method,
             output_layer_init_method, layer_number, sparse=False,
             rpe=None):
    super(ParallelSelfAttention, self).__init__()
    args = get_args()
    self.fp16 = args.fp16

    self.attention_mask_func = attention_mask_func
    self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
    self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
    if self.apply_query_key_layer_scaling:
        self.attention_softmax_in_fp32 = True
    self.layer_number = max(1, layer_number)  # TODO: why do we start from 1 here?

    # Per attention head and per partition values.
    world_size = mpu.get_model_parallel_world_size()
    self.hidden_size_per_partition = mpu.divide(args.hidden_size, world_size)
    self.hidden_size_per_attention_head = mpu.divide(
        args.hidden_size, args.num_attention_heads)
    self.num_attention_heads_per_partition = mpu.divide(
        args.num_attention_heads, world_size)

    # Strided linear layer.
    self.query_key_value = mpu.ColumnParallelLinear(
        args.hidden_size,
        3 * args.hidden_size,
        gather_output=False,
        init_method=init_method)

    coeff = None
    self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
    if self.apply_query_key_layer_scaling:
        coeff = self.layer_number
        self.norm_factor *= coeff

    self.rpe = rpe
    self.sparse = sparse
    if self.sparse:
        sparsity_config = VariableSparsityConfig(
            num_heads=self.num_attention_heads_per_partition,
            attention="unidirectional")
        try:
            self.sparse_attn = SparseSelfAttention(
                sparsity_config=sparsity_config,
                max_seq_length=args.seq_length,
                attn_mask_mode='add',
                mpu=mpu)
        except TypeError:
            # Older versions of SparseSelfAttention don't have the mpu arg.
            self.sparse_attn = SparseSelfAttention(
                sparsity_config=sparsity_config,
                max_seq_length=args.seq_length,
                attn_mask_mode='add')
    else:
        self.scale_mask_softmax = FusedScaleMaskSoftmax(
            self.fp16,
            args.scaled_upper_triang_masked_softmax_fusion,
            args.scaled_masked_softmax_fusion,
            self.attention_mask_func,
            self.attention_softmax_in_fp32,
            coeff)

    # Dropout. Note that for a single iteration, this layer will generate
    # different outputs on different number of parallel partitions but
    # on average it should not be partition dependent.
    self.attention_dropout = torch.nn.Dropout(args.attention_dropout)

    # Output.
    self.dense = mpu.RowParallelLinear(
        args.hidden_size,
        args.hidden_size,
        input_is_parallel=True,
        init_method=output_layer_init_method,
        skip_bias_add=True)

    if deepspeed.checkpointing.is_configured():
        global get_cuda_rng_tracker, checkpoint
        get_cuda_rng_tracker = deepspeed.checkpointing.get_cuda_rng_tracker
        checkpoint = deepspeed.checkpointing.checkpoint
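# Every per-partition size in these constructors goes through mpu.divide. A
# minimal sketch of what such a helper has to guarantee (assumed body; the
# real mpu.divide likewise asserts even divisibility) so that each
# model-parallel rank owns the same number of heads and channels:
def divide(numerator, denominator):
    assert numerator % denominator == 0, (
        "{} is not divisible by {}".format(numerator, denominator))
    return numerator // denominator


# e.g. divide(4096, 8) -> 512 channels per partition, while divide(4096, 3)
# raises immediately instead of silently truncating.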
def __init__(
    self,
    neox_args,
    attention_mask_func,
    init_method,
    output_layer_init_method,
    layer_number,
    rpe=None,
    rotary=False,
    use_cache=False,
    parallel_output=False,
):
    super().__init__()

    self.fp16 = neox_args.precision == "fp16"
    self.bf16 = neox_args.precision == "bfloat16"
    self.attention_mask_func = attention_mask_func
    self.apply_query_key_layer_scaling = neox_args.apply_query_key_layer_scaling
    self.use_cache = use_cache
    self.attention_softmax_in_fp32 = neox_args.attention_softmax_in_fp32
    if self.apply_query_key_layer_scaling:
        self.attention_softmax_in_fp32 = True
    self.layer_number = layer_number

    # Per attention head and per partition values.
    world_size = mpu.get_model_parallel_world_size()
    self.hidden_size_per_partition = mpu.divide(neox_args.hidden_size, world_size)
    self.hidden_size_per_attention_head = mpu.divide(
        neox_args.hidden_size, neox_args.num_attention_heads)
    self.num_attention_heads_per_partition = mpu.divide(
        neox_args.num_attention_heads, world_size)
    self.pos_emb = neox_args.pos_emb

    # Strided linear layer.
    self.query_key_value = mpu.ColumnParallelLinear(
        neox_args=neox_args,
        input_size=neox_args.hidden_size,
        output_size=3 * neox_args.hidden_size,
        gather_output=False,
        init_method=init_method,
    )

    coeff = None
    self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
    if self.apply_query_key_layer_scaling:
        coeff = max(1, self.layer_number)
        self.norm_factor *= coeff

    self.rpe = rpe

    if self.pos_emb == "alibi":
        self.alibi_embed = AliBi(
            neox_args.num_attention_heads,
            neox_args.model_parallel_size,
            mpu.get_model_parallel_rank(),
        )

    # TODO: this arg shouldn't need to be passed in - get from neox_args
    if rotary:
        if neox_args.rotary_pct == 1:
            self.rotary_ndims = None
        else:
            assert neox_args.rotary_pct < 1
            self.rotary_ndims = int(
                self.hidden_size_per_attention_head * neox_args.rotary_pct)
        dim = (self.rotary_ndims
               if self.rotary_ndims is not None
               else self.hidden_size_per_attention_head)
        self.rotary_emb = RotaryEmbedding(dim,
                                          base=neox_args.rotary_emb_base,
                                          precision=neox_args.params_dtype)
    else:
        self.rotary_emb = None

    self.attention_type = neox_args.attention_config[layer_number]
    self.sparse = self.attention_type != "global"
    if self.sparse:
        self.sparse_attn = configure_sparse_attention(
            neox_args,
            self.attention_type,
            self.num_attention_heads_per_partition,
            mpu=mpu,
        )
    else:
        self.scale_mask_softmax = FusedScaleMaskSoftmax(
            input_in_fp16=self.fp16,
            input_in_bf16=self.bf16,
            fusion_type=get_fusion_type(neox_args),
            mask_func=self.attention_mask_func,
            softmax_in_fp32=self.attention_softmax_in_fp32,
            scale=coeff,
        )

    # Dropout. Note that for a single iteration, this layer will generate
    # different outputs on different number of parallel partitions but
    # on average it should not be partition dependent.
    self.attention_dropout = nn.Dropout(neox_args.attention_dropout)

    # Output.
    self.dense = mpu.RowParallelLinear(
        neox_args=neox_args,
        input_size=neox_args.hidden_size,
        output_size=neox_args.hidden_size,
        input_is_parallel=True,
        init_method=output_layer_init_method,
        skip_bias_add=True,
        parallel_output=parallel_output,
    )
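# The rotary branch above only rotates the first rotary_ndims channels of
# each head when rotary_pct < 1 and leaves the rest untouched. A minimal
# sketch of that split (assumed tensor layout [seq, batch, heads, head_dim];
# `rotate_fn` is a hypothetical stand-in for applying RotaryEmbedding):
import torch


def apply_partial_rotary(query, rotary_ndims, rotate_fn):
    if rotary_ndims is None:  # rotary_pct == 1: rotate the whole head.
        return rotate_fn(query)
    q_rot, q_pass = query[..., :rotary_ndims], query[..., rotary_ndims:]
    return torch.cat((rotate_fn(q_rot), q_pass), dim=-1)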
def test_fused_softmax():
    import math

    import torch
    from transformers import BertModel, BertTokenizer

    from megatron.model.fused_softmax import FusedScaleMaskSoftmax, SoftmaxFusionTypes
    from megatron.model.gpt2_model import (
        gpt2_attention_mask_func as attention_mask_func,
    )

    bert = BertModel.from_pretrained("bert-base-cased").cuda().half()
    tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
    test_text = (
        "Hello. How are you? I am fine thank you and you? yes Good. "
        "hi hi hi hi hi hi hi hi hi hi hi hi hi"  # 32
    )
    tokens = tokenizer(
        [test_text] * 4,
        return_tensors="pt",
    )

    embedding_output = bert.embeddings(
        input_ids=tokens["input_ids"].cuda(),
        position_ids=None,
        token_type_ids=tokens["token_type_ids"].cuda(),
        inputs_embeds=None,
        past_key_values_length=0,
    )

    # (bsz, 1, 1, seq_len)
    mask = bert.get_extended_attention_mask(
        attention_mask=tokens["attention_mask"].cuda(),
        input_shape=tokens["input_ids"].shape,
        device=bert.device,
    )
    # (bsz, 1, seq_len, seq_len)
    mask = mask.repeat(1, 1, mask.size()[-1], 1)

    attention = bert.encoder.layer[0].attention.self
    key_layer = attention.transpose_for_scores(attention.key(embedding_output))
    query_layer = attention.transpose_for_scores(attention.query(embedding_output))

    attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
    attention_scores /= math.sqrt(key_layer.size()[-1])

    fused_softmax = (
        FusedScaleMaskSoftmax(
            input_in_fp16=True,
            input_in_bf16=False,
            fusion_type=SoftmaxFusionTypes.general,
            mask_func=attention_mask_func,
            scale=None,
            softmax_in_fp32=False,
        )
        .cuda()
        .half()
    )
    fused_softmax_output = fused_softmax(
        attention_scores,
        (mask != 0),
    )

    torch_softmax = (
        FusedScaleMaskSoftmax(
            input_in_fp16=True,
            input_in_bf16=False,
            mask_func=attention_mask_func,
            fusion_type=SoftmaxFusionTypes.none,
            scale=None,
            softmax_in_fp32=False,
        )
        .cuda()
        .half()
    )
    torch_softmax_output = torch_softmax(
        attention_scores,
        (mask != 0),
    )

    test_result = (fused_softmax_output - torch_softmax_output).abs()
    while test_result.dim() != 1:
        test_result = test_result.mean(dim=-1)
    diff = test_result.mean(dim=-1)

    if diff <= 1e-3:
        print(
            f"\n[Success] test_fused_softmax"
            f"\n > mean_difference={diff}"
            f"\n > fused_values={fused_softmax_output[-1][-1][-1][:5].tolist()}"
            f"\n > torch_values={torch_softmax_output[-1][-1][-1][:5].tolist()}"
        )
    else:
        print(
            f"\n[Fail] test_fused_softmax"
            f"\n > mean_difference={diff}, "
            f"\n > fused_values={fused_softmax_output[-1][-1][-1][:5].tolist()}, "
            f"\n > torch_values={torch_softmax_output[-1][-1][-1][:5].tolist()}"
        )
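# Both tests pass a boolean mask where True marks a position to drop, and
# FusedScaleMaskSoftmax routes it through mask_func before the softmax. A
# hedged sketch of what gpt2_attention_mask_func does with that pair (the
# exact fill constant in the library may differ; -10000.0 matches the
# additive mask built in the GPT-2 test below):
def gpt2_attention_mask_func(attention_scores, ltor_mask):
    # In-place fill of masked-out scores with a large negative value, so
    # they contribute ~0 probability after the softmax.
    attention_scores.masked_fill_(ltor_mask, -10000.0)
    return attention_scores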
def test_fused_upper_triangle_mask_softmax():
    import torch
    from transformers import GPT2Model, GPT2Tokenizer

    from megatron.model.fused_softmax import FusedScaleMaskSoftmax, SoftmaxFusionTypes
    from megatron.model.gpt2_model import (
        gpt2_attention_mask_func as attention_mask_func,
    )

    gpt = GPT2Model.from_pretrained("gpt2").cuda().half()
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    test_text = (
        "Hello. How are you? I am fine thank you and you? yes Good. "
        "hi hi hi hi hi hi hi"  # 24
    )
    tokens = tokenizer(
        [test_text] * 4,
        return_tensors="pt",
    )

    attention_mask = tokens["attention_mask"].cuda()
    attention_mask = attention_mask.view(attention_mask.size(0), -1)
    attention_mask = attention_mask[:, None, None, :]
    attention_mask = (1.0 - attention_mask) * -10000.0
    attention_mask = attention_mask.repeat(1, 1, attention_mask.size()[-1], 1)

    attn = gpt.h[0]
    hidden_states = gpt.wte(tokens["input_ids"].cuda())
    q, k, v = attn.attn.c_attn(hidden_states).split(768, dim=-1)
    q = attn.attn._split_heads(q, attn.attn.num_heads, attn.attn.head_dim)
    k = attn.attn._split_heads(k, attn.attn.num_heads, attn.attn.head_dim)
    attn_weights = torch.matmul(q, k.transpose(-1, -2))

    sq, sk = q.size(-2), k.size(-2)
    causal_mask = attn.attn.bias[:, :, sk - sq:sk, :sk].bool()
    total_mask = ~(causal_mask & (attention_mask == 0))
    # total_mask looks like:
    # tensor([[[[False,  True,  True, ...,  True,  True,  True],
    #           [False, False,  True, ...,  True,  True,  True],
    #           [False, False, False, ...,  True,  True,  True],
    #           ...,
    #           [False, False, False, ..., False,  True,  True],
    #           [False, False, False, ..., False, False,  True],
    #           [False, False, False, ..., False, False, False]]]])

    fused_softmax = (
        FusedScaleMaskSoftmax(
            input_in_fp16=True,
            input_in_bf16=False,
            mask_func=attention_mask_func,
            fusion_type=SoftmaxFusionTypes.upper_triang,
            scale=None,
            softmax_in_fp32=False,
        )
        .cuda()
        .half()
    )
    fused_softmax_output = fused_softmax(
        attn_weights,
        total_mask,
    )

    torch_softmax = (
        FusedScaleMaskSoftmax(
            input_in_fp16=True,
            input_in_bf16=False,
            fusion_type=SoftmaxFusionTypes.none,
            mask_func=attention_mask_func,
            scale=None,
            softmax_in_fp32=False,
        )
        .cuda()
        .half()
    )
    torch_softmax_output = torch_softmax(
        attn_weights,
        total_mask,
    )

    test_result = (fused_softmax_output - torch_softmax_output).abs()
    while test_result.dim() != 1:
        test_result = test_result.mean(dim=-1)
    diff = test_result.mean(dim=-1)

    if diff <= 1e-3:
        print(
            f"\n[Success] test_fused_upper_triangle_mask_softmax"
            f"\n > mean_difference={diff}"
            f"\n > fused_values={fused_softmax_output[-1][-1][-1][:5].tolist()}"
            f"\n > torch_values={torch_softmax_output[-1][-1][-1][:5].tolist()}"
        )
    else:
        print(
            f"\n[Fail] test_fused_upper_triangle_mask_softmax"
            f"\n > mean_difference={diff}, "
            f"\n > fused_values={fused_softmax_output[-1][-1][-1][:5].tolist()}, "
            f"\n > torch_values={torch_softmax_output[-1][-1][-1][:5].tolist()}"
        )
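# A standalone restatement of the mask algebra in the test above, without
# the HuggingFace internals (assumed plain-torch helper, not library code):
# keep a position iff it is on or below the causal diagonal AND the key
# token is not padding; total_mask is then the complement (True = drop).
import torch


def build_total_mask(padding_mask, seq_len):
    # padding_mask: [batch, seq], 1 for real tokens, 0 for padding.
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    keep = causal[None, None, :, :] & padding_mask[:, None, None, :].bool()
    return ~keep  # [batch, 1, seq, seq]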
def __init__(self, init_method, output_layer_init_method, layer_number,
             attention_type=AttnType.self_attn,
             attn_mask_type=AttnMaskType.padding):
    super(ParallelAttention, self).__init__()
    args = get_args()
    self.fp16 = args.fp16
    self.bf16 = args.bf16

    self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
    self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
    if self.apply_query_key_layer_scaling:
        self.attention_softmax_in_fp32 = True
    self.layer_number = max(1, layer_number)
    self.attention_type = attention_type
    self.attn_mask_type = attn_mask_type

    projection_size = args.kv_channels * args.num_attention_heads

    # Per attention head and per partition values.
    world_size = mpu.get_tensor_model_parallel_world_size()
    self.hidden_size_per_partition = mpu.divide(projection_size, world_size)
    self.hidden_size_per_attention_head = mpu.divide(
        projection_size, args.num_attention_heads)
    self.num_attention_heads_per_partition = mpu.divide(
        args.num_attention_heads, world_size)

    # Strided linear layer.
    if attention_type == AttnType.self_attn:
        self.query_key_value = mpu.ColumnParallelLinear(
            args.hidden_size,
            3 * projection_size,
            gather_output=False,
            init_method=init_method)
    else:
        assert attention_type == AttnType.cross_attn
        self.query = mpu.ColumnParallelLinear(
            args.hidden_size,
            projection_size,
            gather_output=False,
            init_method=init_method)
        self.key_value = mpu.ColumnParallelLinear(
            args.hidden_size,
            2 * projection_size,
            gather_output=False,
            init_method=init_method)

    coeff = None
    self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
    if self.apply_query_key_layer_scaling:
        coeff = self.layer_number
        self.norm_factor *= coeff

    self.scale_mask_softmax = FusedScaleMaskSoftmax(
        self.fp16, self.bf16,
        self.attn_mask_type,
        args.masked_softmax_fusion,
        attention_mask_func,
        self.attention_softmax_in_fp32,
        coeff)

    # Dropout. Note that for a single iteration, this layer will generate
    # different outputs on different number of parallel partitions but
    # on average it should not be partition dependent.
    self.attention_dropout = torch.nn.Dropout(args.attention_dropout)

    # Output.
    self.dense = mpu.RowParallelLinear(
        projection_size,
        args.hidden_size,
        input_is_parallel=True,
        init_method=output_layer_init_method,
        skip_bias_add=True)
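# A quick shape check for the cross-attention branch above (hedged sketch
# using plain nn.Linear; the real layers are mpu.ColumnParallelLinear):
# queries come from the decoder stream, while keys and values come from one
# fused 2 * projection_size matmul over the encoder output, mirroring the
# fused 3x QKV projection used for self-attention.
import torch
import torch.nn as nn

hidden_size = projection_size = 1024
query_proj = nn.Linear(hidden_size, projection_size)
key_value_proj = nn.Linear(hidden_size, 2 * projection_size)

encoder_out = torch.randn(2, 5, hidden_size)  # [batch, src_len, hidden]
decoder_in = torch.randn(2, 7, hidden_size)   # [batch, tgt_len, hidden]

q = query_proj(decoder_in)                           # [2, 7, projection_size]
k, v = key_value_proj(encoder_out).chunk(2, dim=-1)  # each [2, 5, projection_size]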