Example #1
    def __init__(self, init_method, output_layer_init_method):
        super(ParallelMLP, self).__init__()
        args = get_args()

        if args.geglu:
            self.activation_type = "geglu"
            mult = 8
            self.activation_func = GEGLU()
        else:
            self.activation_type = "gelu"
            mult = 4
            self.bias_gelu_fusion = args.bias_gelu_fusion
            self.activation_func = F.gelu
            if args.openai_gelu:
                self.activation_func = openai_gelu
            elif args.onnx_safe:
                self.activation_func = erf_gelu

        # Project to 4h.
        self.dense_h_to_4h = mpu.ColumnParallelLinear(
            args.hidden_size,
            mult * args.hidden_size,
            gather_output=False,
            init_method=init_method,
            skip_bias_add=True)

        # Project back to h.
        self.dense_4h_to_h = mpu.RowParallelLinear(
            4 * args.hidden_size,
            args.hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True)
Example #2
    def __init__(self,
                 neox_args,
                 init_method,
                 output_layer_init_method,
                 parallel_output=False):
        super().__init__()

        self.activation_func = get_activation(neox_args)
        self.activation_type = neox_args.activation
        self.bias_gelu_fusion = neox_args.bias_gelu_fusion

        # auto scale so geglu has equal parameters
        ff_mult = 4 * 2 / 3 if self.activation_type == "geglu" else 4
        ff_dim = (
            int(ff_mult * neox_args.hidden_size) * 2
            if self.activation_type == "geglu"
            else ff_mult * neox_args.hidden_size
        )
        self.dense_h_to_4h = mpu.ColumnParallelLinear(
            neox_args=neox_args,
            input_size=neox_args.hidden_size,
            output_size=ff_dim,
            gather_output=False,
            init_method=init_method,
            skip_bias_add=True,
        )
        ff_dim_in = ff_dim // 2 if self.activation_type == "geglu" else ff_dim
        # Project back to h.
        self.dense_4h_to_h = mpu.RowParallelLinear(
            neox_args=neox_args,
            input_size=ff_dim_in,
            output_size=neox_args.hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True,
            parallel_output=parallel_output,
        )
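
The 4 * 2 / 3 multiplier above keeps the GEGLU variant at roughly the same parameter count as the plain 4*h MLP: the input projection is doubled so it can be split into value and gate halves, while the hidden width shrinks to compensate. A quick standalone check (not from the source; the hidden size is arbitrary and biases are ignored):

h = 1024
plain_params = h * (4 * h) + (4 * h) * h        # h -> 4h and 4h -> h projections
ff_mult = 4 * 2 / 3
ff_dim = int(ff_mult * h) * 2                   # doubled so GEGLU can split into value and gate
geglu_params = h * ff_dim + (ff_dim // 2) * h   # h -> ff_dim and ff_dim // 2 -> h projections
print(plain_params, geglu_params)               # both land near 8 * h**2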
Example #3
    def __init__(
        self,
        neox_args,
        parallel_output=True,
        init_method=nn.init.xavier_normal_,
    ):
        super().__init__()
        parallelism = neox_args.output_layer_parallelism
        if parallelism == "column":
            self.final_linear = mpu.ColumnParallelLinear(
                neox_args=neox_args,
                input_size=neox_args.hidden_size,
                output_size=neox_args.padded_vocab_size,
                bias=False,
                init_method=init_method,
                gather_output=not parallel_output,
                skip_bias_add=False,
            )
        else:
            self.final_linear = mpu.RowParallelLinear(
                neox_args=neox_args,
                input_size=neox_args.hidden_size,
                output_size=neox_args.padded_vocab_size,
                bias=False,
                input_is_parallel=False,
                init_method=init_method,
                parallel_output=parallel_output,
                skip_bias_add=False,
            )
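
In this output head, gather_output=not parallel_output decides whether the column-parallel logits are all-gathered across model-parallel ranks or left split along the padded vocabulary dimension; keeping them split is what a vocab-parallel loss expects. A hedged usage sketch (the call site, hidden_states and labels are assumptions, not part of the source):

        # Sketch only: the mpu linear layers return (output, bias); with parallel_output=True
        # the logits stay partitioned over padded_vocab_size on each rank, so the loss should
        # come from the vocab-parallel cross entropy rather than a plain one.
        logits, _ = self.final_linear(hidden_states)
        loss = mpu.vocab_parallel_cross_entropy(logits.float(), labels)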
Example #4
    def __init__(self, init_method, output_layer_init_method):
        super(ParallelMLP, self).__init__()
        args = get_args()

        # Project to 4h.
        self.dense_h_to_4h = mpu.ColumnParallelLinear(args.hidden_size,
                                                      4 * args.hidden_size,
                                                      gather_output=False,
                                                      init_method=init_method,
                                                      skip_bias_add=True)

        self.bias_gelu_fusion = args.bias_gelu_fusion
        self.activation_func = F.gelu
        if args.openai_gelu:
            self.activation_func = openai_gelu
        elif args.onnx_safe:
            self.activation_func = erf_gelu

        # Project back to h.
        self.dense_4h_to_h = mpu.RowParallelLinear(
            4 * args.hidden_size,
            args.hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True)

        if self.dense_h_to_4h.bias is not None:
            deepspeed.zero.register_external_parameter(self,
                                                       self.dense_h_to_4h.bias)
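
The ColumnParallelLinear / RowParallelLinear pair above splits the 4h intermediate across model-parallel ranks, so the activation runs purely rank-locally and only the second projection triggers an all-reduce. A minimal sketch of the matching forward pass in the usual Megatron style (bias_gelu_impl stands for the fused bias-GeLU kernel these files normally import; exact names may differ):

    def forward(self, hidden_states):
        # [s, b, 4h / model_parallel_size]; the bias comes back separately (skip_bias_add=True)
        intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states)
        if self.bias_gelu_fusion:
            intermediate_parallel = bias_gelu_impl(intermediate_parallel, bias_parallel)
        else:
            intermediate_parallel = self.activation_func(intermediate_parallel + bias_parallel)
        # RowParallelLinear all-reduces the partial results back to the full hidden size
        output, output_bias = self.dense_4h_to_h(intermediate_parallel)
        return output, output_bias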
Example #5
    def __init__(self, attention_mask_func, init_method,
                 output_layer_init_method, layer_number):
        super(ParallelSelfAttention, self).__init__()
        args = get_args()
        self.fp16 = args.fp16

        self.attention_mask_func = attention_mask_func
        self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
        self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
        if self.apply_query_key_layer_scaling:
            self.attention_softmax_in_fp32 = True
        self.layer_number = max(1, layer_number)

        # Per attention head and per partition values.
        world_size = mpu.get_model_parallel_world_size()
        self.hidden_size_per_partition = mpu.divide(args.hidden_size,
                                                    world_size)
        self.hidden_size_per_attention_head = mpu.divide(
            args.hidden_size, args.num_attention_heads)
        self.num_attention_heads_per_partition = mpu.divide(
            args.num_attention_heads, world_size)

        # Strided linear layer.
        self.query_key_value = mpu.ColumnParallelLinear(
            args.hidden_size,
            3 * args.hidden_size,
            gather_output=False,
            init_method=init_method)

        coeff = None
        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
        if self.apply_query_key_layer_scaling:
            coeff = self.layer_number
            self.norm_factor *= coeff

        self.scale_mask_softmax = FusedScaleMaskSoftmax(
            self.fp16, args.scaled_upper_triang_masked_softmax_fusion,
            args.scaled_masked_softmax_fusion, self.attention_mask_func,
            self.attention_softmax_in_fp32, coeff)

        # Dropout. Note that for a single iteration, this layer will generate
        # different outputs on different number of parallel partitions but
        # on average it should not be partition dependent.
        self.attention_dropout = torch.nn.Dropout(args.attention_dropout)

        # Output.
        self.dense = mpu.RowParallelLinear(
            args.hidden_size,
            args.hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True)

        if deepspeed.checkpointing.is_configured():
            global get_cuda_rng_tracker, checkpoint
            get_cuda_rng_tracker = deepspeed.checkpointing.get_cuda_rng_tracker
            checkpoint = deepspeed.checkpointing.checkpoint
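
The 3 * hidden_size column-parallel query_key_value projection emits fused Q/K/V only for the heads owned by this rank. A sketch of how such a fused output is typically unpacked at the start of forward in Megatron-style attention (mpu.split_tensor_along_last_dim is an existing mpu helper; the surrounding shapes follow the per-partition values computed above):

        # hidden_states: [sq, b, h] -> mixed_x_layer: [sq, b, 3 * h / model_parallel_size]
        mixed_x_layer, _ = self.query_key_value(hidden_states)
        new_shape = mixed_x_layer.size()[:-1] + (
            self.num_attention_heads_per_partition,
            3 * self.hidden_size_per_attention_head,
        )
        mixed_x_layer = mixed_x_layer.view(*new_shape)
        # each of query/key/value: [sq, b, num_heads_per_partition, head_dim]
        query_layer, key_layer, value_layer = mpu.split_tensor_along_last_dim(mixed_x_layer, 3)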
Example #6
    def __init__(self, init_method, output_layer_init_method):
        super(ParallelMLP, self).__init__()
        args = get_args()

        # Project to 4h.
        if not args.memory_centric_tiled_linear:
            self.dense_h_to_4h = mpu.ColumnParallelLinear(
                args.hidden_size,
                4 * args.hidden_size,
                gather_output=False,
                init_method=init_method,
                skip_bias_add=True)
        else:
            self.dense_h_to_4h = deepspeed.zero.TiledLinearReturnBias(
                in_features=args.hidden_size,
                out_features=4 * args.hidden_size,
                linear_cls=mpu.ColumnParallelLinear,
                in_splits=args.tile_factor,
                out_splits=4 * args.tile_factor,
                combine_out_splits=True,
                gather_output=False,
                init_method=init_method,
                skip_bias_add=True)

        self.bias_gelu_fusion = args.bias_gelu_fusion
        self.activation_func = F.gelu
        if args.openai_gelu:
            self.activation_func = openai_gelu
        elif args.onnx_safe:
            self.activation_func = erf_gelu

        # Project back to h.
        if not args.memory_centric_tiled_linear:
            self.dense_4h_to_h = mpu.RowParallelLinear(
                4 * args.hidden_size,
                args.hidden_size,
                input_is_parallel=True,
                init_method=output_layer_init_method,
                skip_bias_add=True)
        else:
            self.dense_4h_to_h = deepspeed.zero.TiledLinearReturnBias(
                in_features=4 * args.hidden_size,
                out_features=args.hidden_size,
                linear_cls=mpu.RowParallelLinear,
                in_splits=4 * args.tile_factor,
                out_splits=args.tile_factor,
                input_is_already_split=False,
                combine_out_splits=True,
                input_is_parallel=True,
                init_method=output_layer_init_method,
                skip_bias_add=True)
Example #7
    def __init__(
        self,
        neox_args,
        init_method,
        output_layer_init_method,
        layer_number,
        ff_mult=4,
        mask_fn=None,
    ):
        super().__init__()
        self.layer_number = layer_number

        ff_dim = neox_args.hidden_size * ff_mult
        norm, eps = get_norm(neox_args)
        self.norm = norm(neox_args.hidden_size, eps=eps)
        self.input_linear = mpu.ColumnParallelLinear(
            neox_args=neox_args,
            input_size=neox_args.hidden_size,
            output_size=ff_dim * 2,
            gather_output=False,
            init_method=init_method,
            skip_bias_add=True,
        )
        self.activation_func = get_activation(neox_args)
        ff_dim_parallel = mpu.divide(ff_dim,
                                     mpu.get_model_parallel_world_size())
        if neox_args.attention_config[layer_number] == "amlp":
            d_attn = neox_args.gmlp_attn_dim
        else:
            d_attn = None
        self.sgu = SpatialGatingUnit(neox_args,
                                     ff_dim_parallel,
                                     d_attn,
                                     causal=True,
                                     mask_fn=mask_fn)
        self.output_linear = mpu.RowParallelLinear(
            neox_args=neox_args,
            input_size=ff_dim,
            output_size=neox_args.hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True,
        )
Example #8
    def __init__(self, attention_mask_func, init_method,
                 output_layer_init_method, layer_number):
        super(ParallelSelfAttention, self).__init__()
        args = get_args()
        self.fp16 = args.fp16

        self.attention_mask_func = attention_mask_func
        self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
        self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
        if self.apply_query_key_layer_scaling:
            self.attention_softmax_in_fp32 = True
        self.layer_number = max(1, layer_number)

        # Per attention head and per partition values.
        world_size = mpu.get_model_parallel_world_size()
        self.hidden_size_per_partition = mpu.divide(args.hidden_size,
                                                    world_size)
        self.hidden_size_per_attention_head = mpu.divide(
            args.hidden_size, args.num_attention_heads)
        self.num_attention_heads_per_partition = mpu.divide(
            args.num_attention_heads, world_size)

        # Strided linear layer.
        self.query_key_value = mpu.ColumnParallelLinear(
            args.hidden_size,
            3 * args.hidden_size,
            stride=3,
            gather_output=False,
            init_method=init_method)

        # Dropout. Note that for a single iteration, this layer will generate
        # different outputs on different number of parallel partitions but
        # on average it should not be partition dependent.
        self.attention_dropout = torch.nn.Dropout(args.attention_dropout)

        # Output.
        self.dense = mpu.RowParallelLinear(
            args.hidden_size,
            args.hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method)
        self.output_dropout = torch.nn.Dropout(args.hidden_dropout)
Example #9
    def __init__(self, mlp_activation_func, init_method,
                 output_layer_init_method):
        super(ParallelMLP, self).__init__()
        args = get_args()

        # Project to 4h.
        self.dense_h_to_4h = mpu.ColumnParallelLinear(args.hidden_size,
                                                      4 * args.hidden_size,
                                                      gather_output=False,
                                                      init_method=init_method)

        self.activation_func = mlp_activation_func

        # Project back to h.
        self.dense_4h_to_h = mpu.RowParallelLinear(
            4 * args.hidden_size,
            args.hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method)

        self.dropout = torch.nn.Dropout(args.hidden_dropout)
Example #10
    def __init__(self, attention_mask_func, init_method,
                 output_layer_init_method, layer_number, sparse=False,
                 rpe=None):
        super(ParallelSelfAttention, self).__init__()
        args = get_args()
        self.fp16 = args.fp16

        self.attention_mask_func = attention_mask_func
        self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
        self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
        if self.apply_query_key_layer_scaling:
            self.attention_softmax_in_fp32 = True
        self.layer_number = max(1, layer_number)  # TODO: why do we start from 1 here?
        # Per attention head and per partition values.
        world_size = mpu.get_model_parallel_world_size()
        self.hidden_size_per_partition = mpu.divide(args.hidden_size,
                                                    world_size)
        self.hidden_size_per_attention_head = mpu.divide(
            args.hidden_size, args.num_attention_heads)
        self.num_attention_heads_per_partition = mpu.divide(
            args.num_attention_heads, world_size)

        # Strided linear layer.
        self.query_key_value = mpu.ColumnParallelLinear(
            args.hidden_size,
            3 * args.hidden_size,
            gather_output=False,
            init_method=init_method)

        coeff = None
        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
        if self.apply_query_key_layer_scaling:
            coeff = self.layer_number
            self.norm_factor *= coeff

        self.rpe = rpe

        self.sparse = sparse
        if self.sparse:
            sparsity_config = VariableSparsityConfig(
                num_heads=self.num_attention_heads_per_partition,
                attention="unidirectional"
            )
            try:
                self.sparse_attn = SparseSelfAttention(
                    sparsity_config=sparsity_config,
                    max_seq_length=args.seq_length,
                    attn_mask_mode='add',
                    mpu=mpu)
            except TypeError:
                # older versions don't have the mpu arg
                self.sparse_attn = SparseSelfAttention(
                    sparsity_config=sparsity_config,
                    max_seq_length=args.seq_length,
                    attn_mask_mode='add')
        else:
            self.scale_mask_softmax = FusedScaleMaskSoftmax(
                self.fp16,
                args.scaled_upper_triang_masked_softmax_fusion,
                args.scaled_masked_softmax_fusion,
                self.attention_mask_func,
                self.attention_softmax_in_fp32,
                coeff)

            # Dropout. Note that for a single iteration, this layer will generate
            # different outputs on different number of parallel partitions but
            # on average it should not be partition dependent.
            self.attention_dropout = torch.nn.Dropout(args.attention_dropout)

        # Output.
        self.dense = mpu.RowParallelLinear(
            args.hidden_size,
            args.hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True)

        if deepspeed.checkpointing.is_configured():
            global get_cuda_rng_tracker, checkpoint
            get_cuda_rng_tracker = deepspeed.checkpointing.get_cuda_rng_tracker
            checkpoint = deepspeed.checkpointing.checkpoint
Example #11
    def __init__(
        self,
        neox_args,
        attention_mask_func,
        init_method,
        output_layer_init_method,
        layer_number,
        rpe=None,
        rotary=False,
        use_cache=False,
        parallel_output=False,
    ):
        super().__init__()

        self.fp16 = neox_args.precision == "fp16"
        self.bf16 = neox_args.precision == "bfloat16"
        self.attention_mask_func = attention_mask_func
        self.apply_query_key_layer_scaling = neox_args.apply_query_key_layer_scaling
        self.use_cache = use_cache
        self.attention_softmax_in_fp32 = neox_args.attention_softmax_in_fp32
        if self.apply_query_key_layer_scaling:
            self.attention_softmax_in_fp32 = True
        self.layer_number = layer_number
        # Per attention head and per partition values.
        world_size = mpu.get_model_parallel_world_size()
        self.hidden_size_per_partition = mpu.divide(neox_args.hidden_size,
                                                    world_size)
        self.hidden_size_per_attention_head = mpu.divide(
            neox_args.hidden_size, neox_args.num_attention_heads)
        self.num_attention_heads_per_partition = mpu.divide(
            neox_args.num_attention_heads, world_size)
        self.pos_emb = neox_args.pos_emb

        # Strided linear layer.
        self.query_key_value = mpu.ColumnParallelLinear(
            neox_args=neox_args,
            input_size=neox_args.hidden_size,
            output_size=3 * neox_args.hidden_size,
            gather_output=False,
            init_method=init_method,
        )

        coeff = None
        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
        if self.apply_query_key_layer_scaling:
            coeff = max(1, self.layer_number)
            self.norm_factor *= coeff

        self.rpe = rpe

        if self.pos_emb == "alibi":
            self.alibi_embed = AliBi(
                neox_args.num_attention_heads,
                neox_args.model_parallel_size,
                mpu.get_model_parallel_rank(),
            )

        # TODO: this arg shouldn't need to be passed in - get from neox_args
        if rotary:
            if neox_args.rotary_pct == 1:
                self.rotary_ndims = None
            else:
                assert neox_args.rotary_pct < 1
                self.rotary_ndims = int(self.hidden_size_per_attention_head *
                                        neox_args.rotary_pct)
            dim = (self.rotary_ndims if self.rotary_ndims is not None else
                   self.hidden_size_per_attention_head)
            self.rotary_emb = RotaryEmbedding(dim,
                                              base=neox_args.rotary_emb_base,
                                              precision=neox_args.params_dtype)
        else:
            self.rotary_emb = None

        self.attention_type = neox_args.attention_config[layer_number]
        self.sparse = self.attention_type != "global"
        if self.sparse:
            self.sparse_attn = configure_sparse_attention(
                neox_args,
                self.attention_type,
                self.num_attention_heads_per_partition,
                mpu=mpu,
            )
        else:
            self.scale_mask_softmax = FusedScaleMaskSoftmax(
                input_in_fp16=self.fp16,
                input_in_bf16=self.bf16,
                fusion_type=get_fusion_type(neox_args),
                mask_func=self.attention_mask_func,
                softmax_in_fp32=self.attention_softmax_in_fp32,
                scale=coeff,
            )

            # Dropout. Note that for a single iteration, this layer will generate
            # different outputs on different number of parallel partitions but
            # on average it should not be partition dependent.
            self.attention_dropout = nn.Dropout(neox_args.attention_dropout)

        # Output.
        self.dense = mpu.RowParallelLinear(
            neox_args=neox_args,
            input_size=neox_args.hidden_size,
            output_size=neox_args.hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True,
            parallel_output=parallel_output,
        )
Example #12
    def __init__(self,
                 init_method,
                 output_layer_init_method,
                 layer_number,
                 attention_type=AttnType.self_attn,
                 attn_mask_type=AttnMaskType.padding):
        super(ParallelAttention, self).__init__()
        args = get_args()
        self.fp16 = args.fp16
        self.bf16 = args.bf16

        self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
        self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
        if self.apply_query_key_layer_scaling:
            self.attention_softmax_in_fp32 = True
        self.layer_number = max(1, layer_number)
        self.attention_type = attention_type
        self.attn_mask_type = attn_mask_type

        projection_size = args.kv_channels * args.num_attention_heads

        # Per attention head and per partition values.
        world_size = mpu.get_tensor_model_parallel_world_size()
        self.hidden_size_per_partition = mpu.divide(projection_size,
                                                    world_size)
        self.hidden_size_per_attention_head = mpu.divide(
            projection_size, args.num_attention_heads)
        self.num_attention_heads_per_partition = mpu.divide(
            args.num_attention_heads, world_size)

        # Strided linear layer.
        if attention_type == AttnType.self_attn:
            self.query_key_value = mpu.ColumnParallelLinear(
                args.hidden_size,
                3 * projection_size,
                gather_output=False,
                init_method=init_method)
        else:
            assert attention_type == AttnType.cross_attn
            self.query = mpu.ColumnParallelLinear(args.hidden_size,
                                                  projection_size,
                                                  gather_output=False,
                                                  init_method=init_method)

            self.key_value = mpu.ColumnParallelLinear(args.hidden_size,
                                                      2 * projection_size,
                                                      gather_output=False,
                                                      init_method=init_method)

        coeff = None
        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
        if self.apply_query_key_layer_scaling:
            coeff = self.layer_number
            self.norm_factor *= coeff

        self.scale_mask_softmax = FusedScaleMaskSoftmax(
            self.fp16, self.bf16, self.attn_mask_type,
            args.masked_softmax_fusion, attention_mask_func,
            self.attention_softmax_in_fp32, coeff)

        # Dropout. Note that for a single iteration, this layer will generate
        # different outputs on different number of parallel partitions but
        # on average it should not be partition dependent.
        self.attention_dropout = torch.nn.Dropout(args.attention_dropout)

        # Output.
        self.dense = mpu.RowParallelLinear(
            projection_size,
            args.hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True)
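
With apply_query_key_layer_scaling enabled, norm_factor above carries an extra factor of layer_number, and the same factor is handed back to FusedScaleMaskSoftmax as scale=coeff, which re-applies it with the softmax computed in fp32. The effective scaling therefore stays 1/sqrt(head_dim), but the pre-softmax fp16 logits of deeper layers are kept small. A simplified sketch of the score path (shapes assumed to be [b, np, sq, hn] for query_layer and [b, np, sk, hn] for key_layer; the real code uses torch.baddbmm on reshaped tensors):

        # Scores divided by norm_factor = sqrt(hn) * layer_number ...
        attention_scores = (
            torch.matmul(query_layer, key_layer.transpose(-1, -2)) / self.norm_factor
        )
        # ... and the fused softmax multiplies by coeff = layer_number again in fp32,
        # so softmax(QK^T / sqrt(hn)) is unchanged while fp16 intermediates stay bounded.
        attention_probs = self.scale_mask_softmax(attention_scores, attention_mask)
        attention_probs = self.attention_dropout(attention_probs)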