def __init__(self, attention_mask_func, init_method,
             output_layer_init_method, layer_number):
    super(ParallelSelfAttention, self).__init__()
    args = get_args()
    self.fp16 = args.fp16

    self.attention_mask_func = attention_mask_func
    self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
    self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
    if self.apply_query_key_layer_scaling:
        self.attention_softmax_in_fp32 = True
    self.layer_number = max(1, layer_number)

    # Per attention head and per partition values.
    world_size = mpu.get_model_parallel_world_size()
    self.hidden_size_per_partition = mpu.divide(args.hidden_size, world_size)
    self.hidden_size_per_attention_head = mpu.divide(
        args.hidden_size, args.num_attention_heads)
    self.num_attention_heads_per_partition = mpu.divide(
        args.num_attention_heads, world_size)

    # Strided linear layer.
    self.query_key_value = mpu.ColumnParallelLinear(
        args.hidden_size,
        3 * args.hidden_size,
        gather_output=False,
        init_method=init_method)

    coeff = None
    self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
    if self.apply_query_key_layer_scaling:
        coeff = self.layer_number
        self.norm_factor *= coeff

    self.scale_mask_softmax = FusedScaleMaskSoftmax(
        self.fp16,
        args.scaled_upper_triang_masked_softmax_fusion,
        args.scaled_masked_softmax_fusion,
        self.attention_mask_func,
        self.attention_softmax_in_fp32,
        coeff)

    # Dropout. Note that for a single iteration, this layer will generate
    # different outputs on different number of parallel partitions but
    # on average it should not be partition dependent.
    self.attention_dropout = torch.nn.Dropout(args.attention_dropout)

    # Output.
    self.dense = mpu.RowParallelLinear(
        args.hidden_size,
        args.hidden_size,
        input_is_parallel=True,
        init_method=output_layer_init_method,
        skip_bias_add=True)

    if deepspeed.checkpointing.is_configured():
        global get_cuda_rng_tracker, checkpoint
        get_cuda_rng_tracker = deepspeed.checkpointing.get_cuda_rng_tracker
        checkpoint = deepspeed.checkpointing.checkpoint
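# Illustrative sketch (not part of the module above; values are made up):
# with apply_query_key_layer_scaling, raw attention scores are divided by
# norm_factor = sqrt(d_head) * layer_number, and FusedScaleMaskSoftmax then
# multiplies by scale=coeff=layer_number, so the effective scaling is still
# 1/sqrt(d_head) while the intermediate matmul output stays small in fp16.
import math

d_head = 64          # hidden_size_per_attention_head
layer_number = 12    # coeff
norm_factor = math.sqrt(d_head) * layer_number

raw_score = 5.0                       # q . k for one position
scaled = raw_score / norm_factor      # what the attention matmul produces
restored = scaled * layer_number      # what the fused softmax applies via `scale`
assert abs(restored - raw_score / math.sqrt(d_head)) < 1e-6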
def __init__(
    self,
    neox_args,
    init_method,
    output_layer_init_method,
    layer_number,
    ff_mult=4,
    mask_fn=None,
):
    super().__init__()
    self.layer_number = layer_number

    ff_dim = neox_args.hidden_size * ff_mult
    norm, eps = get_norm(neox_args)
    self.norm = norm(neox_args.hidden_size, eps=eps)
    self.input_linear = mpu.ColumnParallelLinear(
        neox_args=neox_args,
        input_size=neox_args.hidden_size,
        output_size=ff_dim * 2,
        gather_output=False,
        init_method=init_method,
        skip_bias_add=True,
    )
    self.activation_func = get_activation(neox_args)
    ff_dim_parallel = mpu.divide(ff_dim, mpu.get_model_parallel_world_size())
    if neox_args.attention_config[layer_number] == "amlp":
        d_attn = neox_args.gmlp_attn_dim
    else:
        d_attn = None
    self.sgu = SpatialGatingUnit(
        neox_args, ff_dim_parallel, d_attn, causal=True, mask_fn=mask_fn
    )
    self.output_linear = mpu.RowParallelLinear(
        neox_args=neox_args,
        input_size=ff_dim,
        output_size=neox_args.hidden_size,
        input_is_parallel=True,
        init_method=output_layer_init_method,
        skip_bias_add=True,
    )
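# Illustrative sketch (assumed values, not part of the block above):
# ColumnParallelLinear splits its output dimension across model-parallel
# ranks, so the SpatialGatingUnit only ever sees the per-partition slice
# ff_dim_parallel = ff_dim / world_size; mpu.divide is assumed to act like
# exact integer division that asserts divisibility.
hidden_size = 4096
ff_mult = 4
world_size = 8

ff_dim = hidden_size * ff_mult            # 16384, full feed-forward width
assert ff_dim % world_size == 0           # the divisibility check mpu.divide enforces
ff_dim_parallel = ff_dim // world_size    # 2048 per model-parallel rank
print(ff_dim, ff_dim_parallel)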
def __init__(self, attention_mask_func, init_method,
             output_layer_init_method, layer_number):
    super(ParallelSelfAttention, self).__init__()
    args = get_args()
    self.fp16 = args.fp16

    self.attention_mask_func = attention_mask_func
    self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
    self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
    if self.apply_query_key_layer_scaling:
        self.attention_softmax_in_fp32 = True
    self.layer_number = max(1, layer_number)

    # Per attention head and per partition values.
    world_size = mpu.get_model_parallel_world_size()
    self.hidden_size_per_partition = mpu.divide(args.hidden_size, world_size)
    self.hidden_size_per_attention_head = mpu.divide(
        args.hidden_size, args.num_attention_heads)
    self.num_attention_heads_per_partition = mpu.divide(
        args.num_attention_heads, world_size)

    # Strided linear layer.
    self.query_key_value = mpu.ColumnParallelLinear(
        args.hidden_size,
        3 * args.hidden_size,
        stride=3,
        gather_output=False,
        init_method=init_method)

    # Dropout. Note that for a single iteration, this layer will generate
    # different outputs on different number of parallel partitions but
    # on average it should not be partition dependent.
    self.attention_dropout = torch.nn.Dropout(args.attention_dropout)

    # Output.
    self.dense = mpu.RowParallelLinear(
        args.hidden_size,
        args.hidden_size,
        input_is_parallel=True,
        init_method=output_layer_init_method)
    self.output_dropout = torch.nn.Dropout(args.hidden_dropout)
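# Illustrative sketch (assumed shapes, not part of the module above):
# query_key_value fuses the Q, K, V projections into one matmul; its
# per-partition output can be reshaped per head and split back into three
# head-shaped tensors, which is what the strided layout is arranged for.
import torch

seq_len, batch = 8, 2
hidden_per_partition, heads_per_partition = 256, 4
head_dim = hidden_per_partition // heads_per_partition  # 64

mixed = torch.randn(seq_len, batch, 3 * hidden_per_partition)  # fused QKV output
mixed = mixed.view(seq_len, batch, heads_per_partition, 3 * head_dim)
q, k, v = torch.chunk(mixed, 3, dim=-1)  # each: [seq, batch, heads, head_dim]
assert q.shape == (seq_len, batch, heads_per_partition, head_dim)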
def __init__(self, attention_mask_func, init_method,
             output_layer_init_method, layer_number, sparse=False, rpe=None):
    super(ParallelSelfAttention, self).__init__()
    args = get_args()
    self.fp16 = args.fp16

    self.attention_mask_func = attention_mask_func
    self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
    self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
    if self.apply_query_key_layer_scaling:
        self.attention_softmax_in_fp32 = True
    self.layer_number = max(1, layer_number)  # TODO: why do we start from 1 here?

    # Per attention head and per partition values.
    world_size = mpu.get_model_parallel_world_size()
    self.hidden_size_per_partition = mpu.divide(args.hidden_size, world_size)
    self.hidden_size_per_attention_head = mpu.divide(
        args.hidden_size, args.num_attention_heads)
    self.num_attention_heads_per_partition = mpu.divide(
        args.num_attention_heads, world_size)

    # Strided linear layer.
    self.query_key_value = mpu.ColumnParallelLinear(
        args.hidden_size,
        3 * args.hidden_size,
        gather_output=False,
        init_method=init_method)

    coeff = None
    self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
    if self.apply_query_key_layer_scaling:
        coeff = self.layer_number
        self.norm_factor *= coeff

    self.rpe = rpe
    self.sparse = sparse
    if self.sparse:
        sparsity_config = VariableSparsityConfig(
            num_heads=self.num_attention_heads_per_partition,
            attention="unidirectional"
        )
        try:
            self.sparse_attn = SparseSelfAttention(
                sparsity_config=sparsity_config,
                max_seq_length=args.seq_length,
                attn_mask_mode='add',
                mpu=mpu)
        except TypeError:
            # older versions don't have the mpu arg
            self.sparse_attn = SparseSelfAttention(
                sparsity_config=sparsity_config,
                max_seq_length=args.seq_length,
                attn_mask_mode='add')
    else:
        self.scale_mask_softmax = FusedScaleMaskSoftmax(
            self.fp16,
            args.scaled_upper_triang_masked_softmax_fusion,
            args.scaled_masked_softmax_fusion,
            self.attention_mask_func,
            self.attention_softmax_in_fp32,
            coeff)

    # Dropout. Note that for a single iteration, this layer will generate
    # different outputs on different number of parallel partitions but
    # on average it should not be partition dependent.
    self.attention_dropout = torch.nn.Dropout(args.attention_dropout)

    # Output.
    self.dense = mpu.RowParallelLinear(
        args.hidden_size,
        args.hidden_size,
        input_is_parallel=True,
        init_method=output_layer_init_method,
        skip_bias_add=True)

    if deepspeed.checkpointing.is_configured():
        global get_cuda_rng_tracker, checkpoint
        get_cuda_rng_tracker = deepspeed.checkpointing.get_cuda_rng_tracker
        checkpoint = deepspeed.checkpointing.checkpoint
def __init__(
    self,
    neox_args,
    attention_mask_func,
    init_method,
    output_layer_init_method,
    layer_number,
    rpe=None,
    rotary=False,
    use_cache=False,
    parallel_output=False,
):
    super().__init__()

    self.fp16 = neox_args.precision == "fp16"
    self.bf16 = neox_args.precision == "bfloat16"
    self.attention_mask_func = attention_mask_func
    self.apply_query_key_layer_scaling = neox_args.apply_query_key_layer_scaling
    self.use_cache = use_cache
    self.attention_softmax_in_fp32 = neox_args.attention_softmax_in_fp32
    if self.apply_query_key_layer_scaling:
        self.attention_softmax_in_fp32 = True
    self.layer_number = layer_number

    # Per attention head and per partition values.
    world_size = mpu.get_model_parallel_world_size()
    self.hidden_size_per_partition = mpu.divide(neox_args.hidden_size, world_size)
    self.hidden_size_per_attention_head = mpu.divide(
        neox_args.hidden_size, neox_args.num_attention_heads)
    self.num_attention_heads_per_partition = mpu.divide(
        neox_args.num_attention_heads, world_size)
    self.pos_emb = neox_args.pos_emb

    # Strided linear layer.
    self.query_key_value = mpu.ColumnParallelLinear(
        neox_args=neox_args,
        input_size=neox_args.hidden_size,
        output_size=3 * neox_args.hidden_size,
        gather_output=False,
        init_method=init_method,
    )

    coeff = None
    self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
    if self.apply_query_key_layer_scaling:
        coeff = max(1, self.layer_number)
        self.norm_factor *= coeff

    self.rpe = rpe

    if self.pos_emb == "alibi":
        self.alibi_embed = AliBi(
            neox_args.num_attention_heads,
            neox_args.model_parallel_size,
            mpu.get_model_parallel_rank(),
        )

    # TODO: this arg shouldn't need to be passed in - get from neox_args
    if rotary:
        if neox_args.rotary_pct == 1:
            self.rotary_ndims = None
        else:
            assert neox_args.rotary_pct < 1
            self.rotary_ndims = int(
                self.hidden_size_per_attention_head * neox_args.rotary_pct
            )
        dim = (
            self.rotary_ndims
            if self.rotary_ndims is not None
            else self.hidden_size_per_attention_head
        )
        self.rotary_emb = RotaryEmbedding(
            dim, base=neox_args.rotary_emb_base, precision=neox_args.params_dtype
        )
    else:
        self.rotary_emb = None

    self.attention_type = neox_args.attention_config[layer_number]
    self.sparse = self.attention_type != "global"
    if self.sparse:
        self.sparse_attn = configure_sparse_attention(
            neox_args,
            self.attention_type,
            self.num_attention_heads_per_partition,
            mpu=mpu,
        )
    else:
        self.scale_mask_softmax = FusedScaleMaskSoftmax(
            input_in_fp16=self.fp16,
            input_in_bf16=self.bf16,
            fusion_type=get_fusion_type(neox_args),
            mask_func=self.attention_mask_func,
            softmax_in_fp32=self.attention_softmax_in_fp32,
            scale=coeff,
        )

    # Dropout. Note that for a single iteration, this layer will generate
    # different outputs on different number of parallel partitions but
    # on average it should not be partition dependent.
    self.attention_dropout = nn.Dropout(neox_args.attention_dropout)

    # Output.
    self.dense = mpu.RowParallelLinear(
        neox_args=neox_args,
        input_size=neox_args.hidden_size,
        output_size=neox_args.hidden_size,
        input_is_parallel=True,
        init_method=output_layer_init_method,
        skip_bias_add=True,
        parallel_output=parallel_output,
    )
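# Illustrative sketch (assumed values, not part of the module above):
# when rotary_pct < 1, only the first rotary_ndims channels of each head
# receive rotary position embeddings; rotary_pct == 1 leaves rotary_ndims
# as None and the full head dimension is rotated.
hidden_size = 6144
num_attention_heads = 64
rotary_pct = 0.25

hidden_size_per_attention_head = hidden_size // num_attention_heads  # 96
rotary_ndims = int(hidden_size_per_attention_head * rotary_pct)      # 24
print(hidden_size_per_attention_head, rotary_ndims)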