Example #1
    def __init__(self,
                 *args,
                 block_size=16,
                 num_random_blocks=None,
                 sparse_attn_global_indices=[],
                 **kwargs):
        super().__init__(*args, **kwargs)
        from deepspeed.ops.sparse_attention import SparseSelfAttention, VariableSparsityConfig
        self.block_size = block_size

        # default to making a quarter of the blocks random
        num_random_blocks = default(num_random_blocks,
                                    self.seq_len // block_size // 4)
        # convert global token indices to (deduplicated) block indices
        global_blocks = uniq(
            map(lambda t: t // block_size, sparse_attn_global_indices))

        self.attn_fn = SparseSelfAttention(
            sparsity_config=VariableSparsityConfig(
                num_heads=self.heads,
                block=self.block_size,
                num_random_blocks=num_random_blocks,
                global_block_indices=global_blocks,
                attention='unidirectional'
                if self.causal else 'bidirectional'),
            max_seq_length=self.seq_len,
            attn_mask_mode='add')
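
Examples #1–#3 call a few small helpers (`exists`, `default`, `uniq`) and, in Example #3, a `dense_attn` fallback that are defined elsewhere in their source files. The definitions below are a minimal sketch consistent with how the helpers are called here, not necessarily the original implementations.

def exists(val):
    # a value counts as "existing" when it is not None
    return val is not None

def default(val, d):
    # fall back to d when val was not provided
    return val if exists(val) else d

def uniq(arr):
    # deduplicate while preserving order; works on any iterable, including map objects
    return list(dict.fromkeys(arr))
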
Example #2
    def __init__(
        self,
        *args,
        block_size = 16,
        num_random_blocks = None,
        **kwargs
    ):
        super().__init__(*args, **kwargs)
        assert not exists(self.tie_attn_dim), 'sparse attention is not compatible with tying of row attention'
        assert exists(self.seq_len), '`seq_len` must be defined if using sparse attention class'
        from deepspeed.ops.sparse_attention import SparseSelfAttention, VariableSparsityConfig

        self.block_size = block_size
        num_random_blocks = default(num_random_blocks, self.seq_len // block_size // 4)

        self.attn_fn = SparseSelfAttention(
            sparsity_config = VariableSparsityConfig(
                num_heads = self.heads,
                block = self.block_size,
                num_random_blocks = num_random_blocks,
                attention = 'bidirectional'
            ),
            max_seq_length = self.seq_len,
            attn_mask_mode = 'add'
        )
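
In both of the examples above, `self.attn_fn` ends up being DeepSpeed's `SparseSelfAttention`, whose forward pass expects query, key and value tensors of shape (batch, heads, seq_len, head_dim) in fp16 on a CUDA device (the underlying kernels require it), with seq_len a multiple of the block size. A hedged call sketch; the shapes and the additive mask below are invented for illustration:

import torch

b, h, n, d = 2, 8, 1024, 64   # batch, heads, seq_len (multiple of block_size), head dim
q = torch.randn(b, h, n, d, device="cuda", dtype=torch.half)
k = torch.randn(b, h, n, d, device="cuda", dtype=torch.half)
v = torch.randn(b, h, n, d, device="cuda", dtype=torch.half)

# with attn_mask_mode='add' the mask is additive: 0 keeps a position visible,
# a large negative value masks it out
attn_mask = torch.zeros(n, n, device="cuda", dtype=torch.half)

# `layer` is assumed to be an instance of one of the classes above
out = layer.attn_fn(q, k, v, attn_mask=attn_mask)   # -> (b, h, n, d)
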
Example #3
    def __init__(self,
                 dim,
                 heads,
                 seq_len,
                 causal=True,
                 dim_head=64,
                 dropout=0.,
                 sparse_attn=False):
        super().__init__()
        inner_dim = heads * dim_head
        self.causal = causal
        self.heads = heads
        self.scale = dim_head**-0.5
        self.dropout = nn.Dropout(dropout)

        if sparse_attn:
            from deepspeed.ops.sparse_attention import SparseSelfAttention, VariableSparsityConfig

            # block size and block-layout counts are left at VariableSparsityConfig's
            # defaults; only the head count and directionality are set explicitly here
            sparsity_config = VariableSparsityConfig(
                num_heads=heads,
                attention=("unidirectional" if causal else "bidirectional"))

            self.attn_fn = SparseSelfAttention(sparsity_config=sparsity_config,
                                               max_seq_length=seq_len,
                                               attn_mask_mode='add')
        else:
            self.attn_fn = partial(dense_attn, dropout_fn=self.dropout)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        self.to_out = nn.Linear(inner_dim, dim)
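
Example #3 again only shows the constructor; a forward pass would split the fused qkv projection into per-head tensors before handing them to `attn_fn`. The method below is a hedged sketch of that wiring (the mask handling and the exact `dense_attn` signature are assumptions, and scaling is assumed to happen inside `attn_fn`):

    def forward(self, x, mask=None):
        b, n, _ = x.shape
        h = self.heads

        # fused qkv projection, then split heads: (b, n, 3*h*d) -> three (b, h, n, d) tensors
        q, k, v = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: t.reshape(b, n, h, -1).transpose(1, 2), (q, k, v))

        # both the sparse and the dense path are assumed to accept (b, h, n, d)
        # tensors plus an optional key padding mask
        out = self.attn_fn(q, k, v, key_padding_mask=mask)

        # merge heads and project back to the model dimension
        out = out.transpose(1, 2).reshape(b, n, -1)
        return self.to_out(out)
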
Example #4
    def __init__(self, attention_mask_func, init_method,
                 output_layer_init_method, layer_number, sparse=False,
                 rpe=None):
        super(ParallelSelfAttention, self).__init__()
        args = get_args()
        self.fp16 = args.fp16

        self.attention_mask_func = attention_mask_func
        self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
        self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
        if self.apply_query_key_layer_scaling:
            self.attention_softmax_in_fp32 = True
        self.layer_number = max(1, layer_number)  # TODO: why do we start from 1 here?
        # Per attention head and per partition values.
        world_size = mpu.get_model_parallel_world_size()
        self.hidden_size_per_partition = mpu.divide(args.hidden_size,
                                                    world_size)
        self.hidden_size_per_attention_head = mpu.divide(
            args.hidden_size, args.num_attention_heads)
        self.num_attention_heads_per_partition = mpu.divide(
            args.num_attention_heads, world_size)

        # Strided linear layer.
        self.query_key_value = mpu.ColumnParallelLinear(
            args.hidden_size,
            3 * args.hidden_size,
            gather_output=False,
            init_method=init_method)

        coeff = None
        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
        if self.apply_query_key_layer_scaling:
            coeff = self.layer_number
            self.norm_factor *= coeff

        self.rpe = rpe

        self.sparse = sparse
        if self.sparse:
            sparsity_config = VariableSparsityConfig(
                num_heads=self.num_attention_heads_per_partition,
                attention="unidirectional"
            )
            try:
                self.sparse_attn = SparseSelfAttention(
                    sparsity_config=sparsity_config,
                    max_seq_length=args.seq_length,
                    attn_mask_mode='add',
                    mpu=mpu)
            except TypeError:
                # older versions don't have the mpu arg
                self.sparse_attn = SparseSelfAttention(
                    sparsity_config=sparsity_config,
                    max_seq_length=args.seq_length,
                    attn_mask_mode='add')
        else:
            self.scale_mask_softmax = FusedScaleMaskSoftmax(
                self.fp16,
                args.scaled_upper_triang_masked_softmax_fusion,
                args.scaled_masked_softmax_fusion,
                self.attention_mask_func,
                self.attention_softmax_in_fp32,
                coeff)

            # Dropout. Note that for a single iteration, this layer will generate
            # different outputs with different numbers of parallel partitions, but
            # on average it should not be partition dependent.
            self.attention_dropout = torch.nn.Dropout(args.attention_dropout)

        # Output.
        self.dense = mpu.RowParallelLinear(
            args.hidden_size,
            args.hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True)

        if deepspeed.checkpointing.is_configured():
            global get_cuda_rng_tracker, checkpoint
            get_cuda_rng_tracker = deepspeed.checkpointing.get_cuda_rng_tracker
            checkpoint = deepspeed.checkpointing.checkpoint
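
Because the sparse branch in Example #4 is built with `attn_mask_mode='add'`, the forward pass has to turn the usual boolean attention mask into an additive bias before calling `self.sparse_attn`. The helper below is a hypothetical sketch of that step, not a line-for-line copy of the model's forward:

def sparse_attention_forward(sparse_attn, query_layer, key_layer, value_layer, attention_mask):
    # assumed layout: query/key/value as (batch, heads, seq_len, head_dim) fp16 CUDA
    # tensors; attention_mask is boolean with True at positions that must be blocked
    additive_mask = attention_mask.to(query_layer.dtype) * -10000.0
    # a relative position bias, if used, would be passed via the `rpe` keyword
    return sparse_attn(query_layer, key_layer, value_layer, attn_mask=additive_mask)
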
Example #5
def configure_sparse_attention(neox_args, attention_type, num_attention_heads,
                               mpu):
    from deepspeed.ops.sparse_attention import (
        SparseSelfAttention,
        VariableSparsityConfig,
        FixedSparsityConfig,
        BigBirdSparsityConfig,
        BSLongformerSparsityConfig,
    )
    from deepspeed.ops.sparse_attention.sparsity_config import LocalSlidingWindowSparsityConfig

    if attention_type == "sparse_fixed":
        # You can think of the local window size as `block` * `num_local_blocks`,
        # so for a local window of 256 tokens set `block` to 16 and `num_local_blocks` to 16.
        sparsity_config = FixedSparsityConfig(
            num_heads=num_attention_heads,
            block=neox_args.sparsity_config.get("block", 16),
            different_layout_per_head=neox_args.sparsity_config.get(
                "different_layout_per_head", False),
            num_local_blocks=neox_args.sparsity_config.get(
                "num_local_blocks", 4),
            num_global_blocks=neox_args.sparsity_config.get(
                "num_global_blocks", 1),
            num_different_global_patterns=neox_args.sparsity_config.get(
                "num_different_global_patterns", 1),
            attention="unidirectional",
            horizontal_global_attention=False,
        )
    elif attention_type == "sparse_variable":
        sparsity_config = VariableSparsityConfig(
            num_heads=num_attention_heads,
            block=neox_args.sparsity_config.get("block", 16),
            different_layout_per_head=neox_args.sparsity_config.get(
                "different_layout_per_head", False),
            num_random_blocks=neox_args.sparsity_config.get(
                "num_random_blocks", 0),
            local_window_blocks=neox_args.sparsity_config.get(
                "local_window_blocks", [4]),
            global_block_indices=neox_args.sparsity_config.get(
                "global_block_indices", [0]),
            global_block_end_indices=neox_args.sparsity_config.get(
                "global_block_end_indices", None),
            attention="unidirectional",
            horizontal_global_attention=False,
        )
    elif attention_type == "local":
        # can configure with `num_local_blocks` or `num_sliding_window_blocks`
        num_local_blocks = neox_args.sparsity_config.get(
            "num_local_blocks",
            neox_args.sparsity_config.get("num_sliding_window_blocks", 4),
        )
        sparsity_config = LocalSlidingWindowSparsityConfig(
            num_heads=num_attention_heads,
            block=neox_args.sparsity_config.get("block", 16),
            num_sliding_window_blocks=num_local_blocks,
            attention="unidirectional",
        )
    elif attention_type == "bigbird":
        sparsity_config = BigBirdSparsityConfig(
            num_heads=num_attention_heads,
            block=neox_args.sparsity_config.get("block", 16),
            different_layout_per_head=neox_args.sparsity_config.get(
                "different_layout_per_head", False),
            num_random_blocks=neox_args.sparsity_config.get(
                "num_random_blocks", 1),
            num_sliding_window_blocks=neox_args.sparsity_config.get(
                "num_sliding_window_blocks", 3),
            num_global_blocks=neox_args.sparsity_config.get(
                "num_global_blocks", 1),
            attention="unidirectional",
        )
    elif attention_type == "bslongformer":
        sparsity_config = BSLongformerSparsityConfig(
            num_heads=num_attention_heads,
            block=neox_args.sparsity_config.get("block", 16),
            different_layout_per_head=neox_args.sparsity_config.get(
                "different_layout_per_head", False),
            num_sliding_window_blocks=neox_args.sparsity_config.get(
                "num_sliding_window_blocks", 3),
            global_block_indices=neox_args.sparsity_config.get(
                "global_block_indices", [0]),
            global_block_end_indices=neox_args.sparsity_config.get(
                "global_block_end_indices", None),
            attention="unidirectional",
        )
    else:
        raise ValueError(f"Attention type {attention_type} not recognized")
    return SparseSelfAttention(
        sparsity_config=sparsity_config,
        max_seq_length=neox_args.seq_length,
        attn_mask_mode="add",
        mpu=mpu,
    )
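
`configure_sparse_attention` reads every tunable from `neox_args.sparsity_config`, a plain dict keyed by the argument names above (the defaults shown are used for missing keys), plus `neox_args.seq_length` for the layout size. A hedged usage sketch with a hypothetical stand-in for the real NeoX argument object:

from types import SimpleNamespace

# hypothetical stand-in exposing only the two attributes the helper reads
neox_args = SimpleNamespace(
    seq_length=2048,                 # must be a multiple of the block size
    sparsity_config={
        "block": 16,
        "num_local_blocks": 32,      # local window of 16 * 32 = 512 tokens
        "num_global_blocks": 1,
    },
)

# `mpu` is the model-parallel utility module already passed around in this codebase
sparse_attn = configure_sparse_attention(
    neox_args, attention_type="sparse_fixed",
    num_attention_heads=16, mpu=mpu)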