def __init__(self, init_method, output_layer_init_method):
    """MLP block: project h -> mult*h, apply GEGLU or a GELU variant,
    project back to h.

    Args:
        init_method: weight init for the input (column-parallel) projection.
        output_layer_init_method: weight init for the output (row-parallel)
            projection.
    """
    super(ParallelMLP, self).__init__()
    args = get_args()

    # BUGFIX: the original only assigned bias_gelu_fusion inside the gelu
    # branch, leaving the attribute undefined when args.geglu is set and
    # crashing any forward() that reads it. Default it to False so the
    # geglu path has a well-defined (disabled) value.
    self.bias_gelu_fusion = False
    if args.geglu:
        self.activation_type = "geglu"
        # GEGLU consumes a doubled projection (gates one half), so project
        # to 8h here; the gated output is 4h.
        mult = 8
        self.activation_func = GEGLU()
    else:
        self.activation_type = "gelu"
        mult = 4
        self.bias_gelu_fusion = args.bias_gelu_fusion
        self.activation_func = F.gelu
        if args.openai_gelu:
            self.activation_func = openai_gelu
        elif args.onnx_safe:
            self.activation_func = erf_gelu

    # Project to mult*h (bias returned separately for possible fusion).
    self.dense_h_to_4h = mpu.ColumnParallelLinear(
        args.hidden_size,
        mult * args.hidden_size,
        gather_output=False,
        init_method=init_method,
        skip_bias_add=True)

    # Project back to h. Input is 4h in both branches: GEGLU halves the
    # 8h projection back down to 4h.
    self.dense_4h_to_h = mpu.RowParallelLinear(
        4 * args.hidden_size,
        args.hidden_size,
        input_is_parallel=True,
        init_method=output_layer_init_method,
        skip_bias_add=True)
def __init__(self, neox_args, init_method, output_layer_init_method, parallel_output=False):
    """Model-parallel feed-forward block: h -> ff_dim -> h.

    The expansion factor is auto-scaled for GEGLU (4 * 2/3, then doubled
    for the gate half) so parameter counts match a plain 4x MLP.
    """
    super().__init__()
    self.activation_func = get_activation(neox_args)
    self.activation_type = neox_args.activation
    self.bias_gelu_fusion = neox_args.bias_gelu_fusion

    hidden = neox_args.hidden_size
    is_geglu = self.activation_type == "geglu"
    # auto scale so geglu has equal parameters
    expansion = 4 * 2 / 3 if is_geglu else 4
    if is_geglu:
        # Double the projection: half is the gate, half the content.
        projected = int(expansion * hidden) * 2
    else:
        projected = expansion * hidden

    self.dense_h_to_4h = mpu.ColumnParallelLinear(
        neox_args=neox_args,
        input_size=hidden,
        output_size=projected,
        gather_output=False,
        init_method=init_method,
        skip_bias_add=True,
    )

    # Only half of a GEGLU projection survives the gating.
    back_in = projected // 2 if is_geglu else projected

    # Project back to h.
    self.dense_4h_to_h = mpu.RowParallelLinear(
        neox_args=neox_args,
        input_size=back_in,
        output_size=hidden,
        input_is_parallel=True,
        init_method=output_layer_init_method,
        skip_bias_add=True,
        parallel_output=parallel_output,
    )
def __init__(
    self,
    neox_args,
    parallel_output=True,
    init_method=nn.init.xavier_normal_,
):
    """Final hidden -> vocab projection, column- or row-parallel per config.

    Both layouts share the same core kwargs; only the parallelism-specific
    flags differ, so pick the class and extras first and build once.
    """
    super().__init__()
    if neox_args.output_layer_parallelism == "column":
        linear_cls = mpu.ColumnParallelLinear
        # Column-parallel: gather only when callers need the full logits.
        extra = {"gather_output": not parallel_output}
    else:
        linear_cls = mpu.RowParallelLinear
        extra = {
            "input_is_parallel": False,
            "parallel_output": parallel_output,
        }
    self.final_linear = linear_cls(
        neox_args=neox_args,
        input_size=neox_args.hidden_size,
        output_size=neox_args.padded_vocab_size,
        bias=False,
        init_method=init_method,
        skip_bias_add=False,
        **extra,
    )
def __init__(self, init_method, output_layer_init_method):
    """4x MLP that registers its skipped bias with DeepSpeed ZeRO.

    Args:
        init_method: weight init for the column-parallel input projection.
        output_layer_init_method: weight init for the row-parallel output
            projection.
    """
    super(ParallelMLP, self).__init__()
    args = get_args()

    # Pick the GELU flavor (openai > onnx-safe > torch default).
    self.bias_gelu_fusion = args.bias_gelu_fusion
    if args.openai_gelu:
        self.activation_func = openai_gelu
    elif args.onnx_safe:
        self.activation_func = erf_gelu
    else:
        self.activation_func = F.gelu

    # Project to 4h.
    self.dense_h_to_4h = mpu.ColumnParallelLinear(
        args.hidden_size,
        4 * args.hidden_size,
        gather_output=False,
        init_method=init_method,
        skip_bias_add=True)

    # Project back to h.
    self.dense_4h_to_h = mpu.RowParallelLinear(
        4 * args.hidden_size,
        args.hidden_size,
        input_is_parallel=True,
        init_method=output_layer_init_method,
        skip_bias_add=True)

    # skip_bias_add means the bias is consumed outside this module's
    # forward; ZeRO needs it registered as an external parameter.
    if self.dense_h_to_4h.bias is not None:
        deepspeed.zero.register_external_parameter(self, self.dense_h_to_4h.bias)
def __init__(self, attention_mask_func, init_method, output_layer_init_method, layer_number):
    """Model-parallel self-attention with fused scale-mask-softmax.

    Args:
        attention_mask_func: callable applied to attention scores + mask.
        init_method: weight init for the QKV projection.
        output_layer_init_method: weight init for the output projection.
        layer_number: 1-based layer index (clamped to >= 1) used for
            query-key layer scaling.
    """
    super(ParallelSelfAttention, self).__init__()
    args = get_args()

    self.fp16 = args.fp16
    self.attention_mask_func = attention_mask_func
    self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
    # Layer-scaled QK products are softmaxed in fp32 for stability.
    self.attention_softmax_in_fp32 = (args.attention_softmax_in_fp32
                                      or self.apply_query_key_layer_scaling)
    self.layer_number = max(1, layer_number)

    # Partition the hidden dimension and heads over the model-parallel group.
    mp_size = mpu.get_model_parallel_world_size()
    self.hidden_size_per_partition = mpu.divide(args.hidden_size, mp_size)
    self.hidden_size_per_attention_head = mpu.divide(
        args.hidden_size, args.num_attention_heads)
    self.num_attention_heads_per_partition = mpu.divide(
        args.num_attention_heads, mp_size)

    # Fused QKV projection (strided across partitions).
    self.query_key_value = mpu.ColumnParallelLinear(
        args.hidden_size,
        3 * args.hidden_size,
        gather_output=False,
        init_method=init_method)

    # 1/sqrt(d) scaling, optionally folded with the per-layer coefficient.
    self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
    coeff = None
    if self.apply_query_key_layer_scaling:
        coeff = self.layer_number
        self.norm_factor *= coeff

    self.scale_mask_softmax = FusedScaleMaskSoftmax(
        self.fp16,
        args.scaled_upper_triang_masked_softmax_fusion,
        args.scaled_masked_softmax_fusion,
        self.attention_mask_func,
        self.attention_softmax_in_fp32,
        coeff)

    # Dropout is applied per partition: single-iteration outputs differ
    # across partition counts but match in expectation.
    self.attention_dropout = torch.nn.Dropout(args.attention_dropout)

    # Output projection back to the full hidden size.
    self.dense = mpu.RowParallelLinear(
        args.hidden_size,
        args.hidden_size,
        input_is_parallel=True,
        init_method=output_layer_init_method,
        skip_bias_add=True)

    # Route RNG tracking and checkpointing through DeepSpeed if configured.
    if deepspeed.checkpointing.is_configured():
        global get_cuda_rng_tracker, checkpoint
        get_cuda_rng_tracker = deepspeed.checkpointing.get_cuda_rng_tracker
        checkpoint = deepspeed.checkpointing.checkpoint
def __init__(self, init_method, output_layer_init_method):
    """4x MLP with optional memory-centric tiled linear layers.

    When args.memory_centric_tiled_linear is set, both projections are
    built as DeepSpeed TiledLinearReturnBias wrappers around the parallel
    linear classes; otherwise plain column/row-parallel linears are used.
    """
    super(ParallelMLP, self).__init__()
    args = get_args()
    tiled = args.memory_centric_tiled_linear

    # Project to 4h.
    if tiled:
        self.dense_h_to_4h = deepspeed.zero.TiledLinearReturnBias(
            in_features=args.hidden_size,
            out_features=4 * args.hidden_size,
            linear_cls=mpu.ColumnParallelLinear,
            in_splits=args.tile_factor,
            out_splits=4 * args.tile_factor,
            combine_out_splits=True,
            gather_output=False,
            init_method=init_method,
            skip_bias_add=True)
    else:
        self.dense_h_to_4h = mpu.ColumnParallelLinear(
            args.hidden_size,
            4 * args.hidden_size,
            gather_output=False,
            init_method=init_method,
            skip_bias_add=True)

    # Pick the GELU flavor (openai > onnx-safe > torch default).
    self.bias_gelu_fusion = args.bias_gelu_fusion
    if args.openai_gelu:
        self.activation_func = openai_gelu
    elif args.onnx_safe:
        self.activation_func = erf_gelu
    else:
        self.activation_func = F.gelu

    # Project back to h.
    if tiled:
        self.dense_4h_to_h = deepspeed.zero.TiledLinearReturnBias(
            in_features=4 * args.hidden_size,
            out_features=args.hidden_size,
            linear_cls=mpu.RowParallelLinear,
            in_splits=4 * args.tile_factor,
            out_splits=args.tile_factor,
            input_is_already_split=False,
            combine_out_splits=True,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True)
    else:
        self.dense_4h_to_h = mpu.RowParallelLinear(
            4 * args.hidden_size,
            args.hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True)
def __init__(
    self,
    neox_args,
    init_method,
    output_layer_init_method,
    layer_number,
    ff_mult=4,
    mask_fn=None,
):
    """gMLP block: norm -> doubled ff projection -> spatial gating -> output.

    Args:
        neox_args: model configuration.
        init_method: weight init for the input projection.
        output_layer_init_method: weight init for the output projection.
        layer_number: index into neox_args.attention_config.
        ff_mult: feed-forward expansion factor.
        mask_fn: optional mask function forwarded to the gating unit.
    """
    super().__init__()
    self.layer_number = layer_number

    hidden = neox_args.hidden_size
    ff_dim = hidden * ff_mult
    norm_cls, eps = get_norm(neox_args)
    self.norm = norm_cls(hidden, eps=eps)

    # Doubled output: the SGU splits it into content and gate halves.
    self.input_linear = mpu.ColumnParallelLinear(
        neox_args=neox_args,
        input_size=hidden,
        output_size=ff_dim * 2,
        gather_output=False,
        init_method=init_method,
        skip_bias_add=True,
    )
    self.activation_func = get_activation(neox_args)

    ff_dim_parallel = mpu.divide(ff_dim, mpu.get_model_parallel_world_size())
    # "amlp" layers attach a small attention inside the gating unit.
    if neox_args.attention_config[layer_number] == "amlp":
        d_attn = neox_args.gmlp_attn_dim
    else:
        d_attn = None
    self.sgu = SpatialGatingUnit(neox_args, ff_dim_parallel, d_attn, causal=True, mask_fn=mask_fn)

    self.output_linear = mpu.RowParallelLinear(
        neox_args=neox_args,
        input_size=ff_dim,
        output_size=hidden,
        input_is_parallel=True,
        init_method=output_layer_init_method,
        skip_bias_add=True,
    )
def __init__(self, attention_mask_func, init_method, output_layer_init_method, layer_number):
    """Model-parallel self-attention (strided QKV, no fused softmax).

    Args:
        attention_mask_func: callable applied to attention scores + mask.
        init_method: weight init for the QKV projection.
        output_layer_init_method: weight init for the output projection.
        layer_number: layer index, clamped to >= 1.
    """
    super(ParallelSelfAttention, self).__init__()
    args = get_args()

    self.fp16 = args.fp16
    self.attention_mask_func = attention_mask_func
    self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
    self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
    if self.apply_query_key_layer_scaling:
        # Layer scaling needs an fp32 softmax for numerical stability.
        self.attention_softmax_in_fp32 = True
    self.layer_number = max(1, layer_number)

    # Split heads and the hidden dimension across the model-parallel group.
    mp_size = mpu.get_model_parallel_world_size()
    self.hidden_size_per_partition = mpu.divide(args.hidden_size, mp_size)
    self.hidden_size_per_attention_head = mpu.divide(
        args.hidden_size, args.num_attention_heads)
    self.num_attention_heads_per_partition = mpu.divide(
        args.num_attention_heads, mp_size)

    # Fused, strided QKV projection.
    self.query_key_value = mpu.ColumnParallelLinear(
        args.hidden_size,
        3 * args.hidden_size,
        stride=3,
        gather_output=False,
        init_method=init_method)

    # Dropout is applied per partition: single-iteration outputs differ
    # across partition counts but match in expectation.
    self.attention_dropout = torch.nn.Dropout(args.attention_dropout)

    # Output projection and dropout.
    self.dense = mpu.RowParallelLinear(
        args.hidden_size,
        args.hidden_size,
        input_is_parallel=True,
        init_method=output_layer_init_method)
    self.output_dropout = torch.nn.Dropout(args.hidden_dropout)
def __init__(self, mlp_activation_func, init_method, output_layer_init_method):
    """Plain 4x MLP with a caller-supplied activation and output dropout.

    Args:
        mlp_activation_func: activation applied between the projections.
        init_method: weight init for the column-parallel input projection.
        output_layer_init_method: weight init for the row-parallel output
            projection.
    """
    super(ParallelMLP, self).__init__()
    args = get_args()

    self.activation_func = mlp_activation_func

    # h -> 4h, partitioned over output columns.
    self.dense_h_to_4h = mpu.ColumnParallelLinear(
        args.hidden_size,
        4 * args.hidden_size,
        gather_output=False,
        init_method=init_method)

    # 4h -> h, partitioned over input rows.
    self.dense_4h_to_h = mpu.RowParallelLinear(
        4 * args.hidden_size,
        args.hidden_size,
        input_is_parallel=True,
        init_method=output_layer_init_method)

    self.dropout = torch.nn.Dropout(args.hidden_dropout)
def __init__(self, attention_mask_func, init_method,
             output_layer_init_method, layer_number, sparse=False, rpe=None):
    """Model-parallel self-attention with optional DeepSpeed sparse attention.

    Args:
        attention_mask_func: callable applied to attention scores + mask.
        init_method: weight init for the QKV projection.
        output_layer_init_method: weight init for the output projection.
        layer_number: layer index, clamped to >= 1 below.
        sparse: if True, use DeepSpeed SparseSelfAttention instead of the
            fused dense softmax.
        rpe: optional relative-position-embedding module, stored for forward.
    """
    super(ParallelSelfAttention, self).__init__()
    args = get_args()
    self.fp16 = args.fp16
    self.attention_mask_func = attention_mask_func
    self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
    self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
    if self.apply_query_key_layer_scaling:
        # Layer scaling forces the softmax into fp32.
        self.attention_softmax_in_fp32 = True
    self.layer_number = max(1, layer_number)  # TODO: why do we start from 1 here?
    # Per attention head and per partition values.
    world_size = mpu.get_model_parallel_world_size()
    self.hidden_size_per_partition = mpu.divide(args.hidden_size, world_size)
    self.hidden_size_per_attention_head = mpu.divide(
        args.hidden_size, args.num_attention_heads)
    self.num_attention_heads_per_partition = mpu.divide(
        args.num_attention_heads, world_size)
    # Strided linear layer: fused QKV projection.
    self.query_key_value = mpu.ColumnParallelLinear(
        args.hidden_size,
        3 * args.hidden_size,
        gather_output=False,
        init_method=init_method)
    # 1/sqrt(d) scaling, optionally folded with the per-layer coefficient.
    coeff = None
    self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
    if self.apply_query_key_layer_scaling:
        coeff = self.layer_number
        self.norm_factor *= coeff
    self.rpe = rpe
    self.sparse = sparse
    if self.sparse:
        sparsity_config = VariableSparsityConfig(
            num_heads=self.num_attention_heads_per_partition,
            attention="unidirectional"
        )
        try:
            self.sparse_attn = SparseSelfAttention(
                sparsity_config=sparsity_config,
                max_seq_length=args.seq_length,
                attn_mask_mode='add',
                mpu=mpu)
        except TypeError:
            # older versions don't have the mpu arg
            self.sparse_attn = SparseSelfAttention(
                sparsity_config=sparsity_config,
                max_seq_length=args.seq_length,
                attn_mask_mode='add')
    else:
        self.scale_mask_softmax = FusedScaleMaskSoftmax(
            self.fp16,
            args.scaled_upper_triang_masked_softmax_fusion,
            args.scaled_masked_softmax_fusion,
            self.attention_mask_func,
            self.attention_softmax_in_fp32,
            coeff)
    # Dropout. Note that for a single iteration, this layer will generate
    # different outputs on different number of parallel partitions but
    # on average it should not be partition dependent.
    self.attention_dropout = torch.nn.Dropout(args.attention_dropout)
    # Output.
    self.dense = mpu.RowParallelLinear(
        args.hidden_size,
        args.hidden_size,
        input_is_parallel=True,
        init_method=output_layer_init_method,
        skip_bias_add=True)
    # Route RNG tracking and checkpointing through DeepSpeed if configured.
    if deepspeed.checkpointing.is_configured():
        global get_cuda_rng_tracker, checkpoint
        get_cuda_rng_tracker = deepspeed.checkpointing.get_cuda_rng_tracker
        checkpoint = deepspeed.checkpointing.checkpoint
def __init__(
    self,
    neox_args,
    attention_mask_func,
    init_method,
    output_layer_init_method,
    layer_number,
    rpe=None,
    rotary=False,
    use_cache=False,
    parallel_output=False,
):
    """Model-parallel self-attention (GPT-NeoX variant).

    Supports rotary / ALiBi position embeddings and per-layer sparse
    attention selected via neox_args.attention_config.

    Args:
        neox_args: model configuration.
        attention_mask_func: callable applied to attention scores + mask.
        init_method: weight init for the QKV projection.
        output_layer_init_method: weight init for the output projection.
        layer_number: index into neox_args.attention_config.
        rpe: optional relative-position-embedding module, stored for forward.
        rotary: if True, build a RotaryEmbedding over (a fraction of) the
            per-head dimension.
        use_cache: whether forward should use a KV cache (stored only).
        parallel_output: forwarded to the output RowParallelLinear.
    """
    super().__init__()

    self.fp16 = neox_args.precision == "fp16"
    self.bf16 = neox_args.precision == "bfloat16"
    self.attention_mask_func = attention_mask_func
    self.apply_query_key_layer_scaling = neox_args.apply_query_key_layer_scaling
    self.use_cache = use_cache
    self.attention_softmax_in_fp32 = neox_args.attention_softmax_in_fp32
    if self.apply_query_key_layer_scaling:
        # Layer scaling forces the softmax into fp32.
        self.attention_softmax_in_fp32 = True
    self.layer_number = layer_number
    # Per attention head and per partition values.
    world_size = mpu.get_model_parallel_world_size()
    self.hidden_size_per_partition = mpu.divide(neox_args.hidden_size, world_size)
    self.hidden_size_per_attention_head = mpu.divide(
        neox_args.hidden_size, neox_args.num_attention_heads)
    self.num_attention_heads_per_partition = mpu.divide(
        neox_args.num_attention_heads, world_size)
    self.pos_emb = neox_args.pos_emb
    # Strided linear layer: fused QKV projection.
    self.query_key_value = mpu.ColumnParallelLinear(
        neox_args=neox_args,
        input_size=neox_args.hidden_size,
        output_size=3 * neox_args.hidden_size,
        gather_output=False,
        init_method=init_method,
    )
    # 1/sqrt(d) scaling, optionally folded with the per-layer coefficient.
    coeff = None
    self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
    if self.apply_query_key_layer_scaling:
        coeff = max(1, self.layer_number)
        self.norm_factor *= coeff
    self.rpe = rpe
    if self.pos_emb == "alibi":
        self.alibi_embed = AliBi(
            neox_args.num_attention_heads,
            neox_args.model_parallel_size,
            mpu.get_model_parallel_rank(),
        )
    # TODO: this arg shouldn't need to be passed in - get from neox_args
    if rotary:
        # rotary_pct selects the fraction of each head dim that is rotated;
        # 1 means the full head dimension (rotary_ndims left as None).
        if neox_args.rotary_pct == 1:
            self.rotary_ndims = None
        else:
            assert neox_args.rotary_pct < 1
            self.rotary_ndims = int(self.hidden_size_per_attention_head * neox_args.rotary_pct)
        dim = (self.rotary_ndims if self.rotary_ndims is not None
               else self.hidden_size_per_attention_head)
        self.rotary_emb = RotaryEmbedding(dim, base=neox_args.rotary_emb_base,
                                          precision=neox_args.params_dtype)
    else:
        self.rotary_emb = None
    # Anything other than "global" attention routes through the sparse path.
    self.attention_type = neox_args.attention_config[layer_number]
    self.sparse = self.attention_type != "global"
    if self.sparse:
        self.sparse_attn = configure_sparse_attention(
            neox_args,
            self.attention_type,
            self.num_attention_heads_per_partition,
            mpu=mpu,
        )
    else:
        self.scale_mask_softmax = FusedScaleMaskSoftmax(
            input_in_fp16=self.fp16,
            input_in_bf16=self.bf16,
            fusion_type=get_fusion_type(neox_args),
            mask_func=self.attention_mask_func,
            softmax_in_fp32=self.attention_softmax_in_fp32,
            scale=coeff,
        )
    # Dropout. Note that for a single iteration, this layer will generate
    # different outputs on different number of parallel partitions but
    # on average it should not be partition dependent.
    self.attention_dropout = nn.Dropout(neox_args.attention_dropout)
    # Output.
    self.dense = mpu.RowParallelLinear(
        neox_args=neox_args,
        input_size=neox_args.hidden_size,
        output_size=neox_args.hidden_size,
        input_is_parallel=True,
        init_method=output_layer_init_method,
        skip_bias_add=True,
        parallel_output=parallel_output,
    )
def __init__(self, init_method,
             output_layer_init_method, layer_number,
             attention_type=AttnType.self_attn,
             attn_mask_type=AttnMaskType.padding):
    """Tensor-model-parallel attention supporting self- and cross-attention.

    Args:
        init_method: weight init for the query/key/value projections.
        output_layer_init_method: weight init for the output projection.
        layer_number: layer index, clamped to >= 1 below.
        attention_type: AttnType.self_attn (fused QKV) or
            AttnType.cross_attn (separate Q and fused KV projections).
        attn_mask_type: mask type forwarded to the fused softmax.
    """
    super(ParallelAttention, self).__init__()
    args = get_args()
    self.fp16 = args.fp16
    self.bf16 = args.bf16

    self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
    self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
    if self.apply_query_key_layer_scaling:
        # Layer scaling forces the softmax into fp32.
        self.attention_softmax_in_fp32 = True
    self.layer_number = max(1, layer_number)
    self.attention_type = attention_type
    self.attn_mask_type = attn_mask_type

    # Total projection width: kv_channels per head times number of heads
    # (may differ from hidden_size).
    projection_size = args.kv_channels * args.num_attention_heads

    # Per attention head and per partition values.
    world_size = mpu.get_tensor_model_parallel_world_size()
    self.hidden_size_per_partition = mpu.divide(projection_size, world_size)
    self.hidden_size_per_attention_head = mpu.divide(
        projection_size, args.num_attention_heads)
    self.num_attention_heads_per_partition = mpu.divide(
        args.num_attention_heads, world_size)

    # Strided linear layer.
    if attention_type == AttnType.self_attn:
        # Self-attention: one fused QKV projection.
        self.query_key_value = mpu.ColumnParallelLinear(
            args.hidden_size,
            3 * projection_size,
            gather_output=False,
            init_method=init_method)
    else:
        assert attention_type == AttnType.cross_attn
        # Cross-attention: queries from this stream, keys/values fused from
        # the encoder stream.
        self.query = mpu.ColumnParallelLinear(
            args.hidden_size,
            projection_size,
            gather_output=False,
            init_method=init_method)
        self.key_value = mpu.ColumnParallelLinear(
            args.hidden_size,
            2 * projection_size,
            gather_output=False,
            init_method=init_method)

    # 1/sqrt(d) scaling, optionally folded with the per-layer coefficient.
    coeff = None
    self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
    if self.apply_query_key_layer_scaling:
        coeff = self.layer_number
        self.norm_factor *= coeff

    self.scale_mask_softmax = FusedScaleMaskSoftmax(
        self.fp16, self.bf16,
        self.attn_mask_type,
        args.masked_softmax_fusion,
        attention_mask_func,
        self.attention_softmax_in_fp32,
        coeff)

    # Dropout. Note that for a single iteration, this layer will generate
    # different outputs on different number of parallel partitions but
    # on average it should not be partition dependent.
    self.attention_dropout = torch.nn.Dropout(args.attention_dropout)

    # Output.
    self.dense = mpu.RowParallelLinear(
        projection_size,
        args.hidden_size,
        input_is_parallel=True,
        init_method=output_layer_init_method,
        skip_bias_add=True)