def __init__(self, *args, block_size=16, num_random_blocks=None, sparse_attn_global_indices=[], **kwargs):
    super().__init__(*args, **kwargs)
    from deepspeed.ops.sparse_attention import SparseSelfAttention, VariableSparsityConfig

    self.block_size = block_size

    # default to one quarter of the sequence blocks being sampled as random blocks
    num_random_blocks = default(num_random_blocks, self.seq_len // block_size // 4)

    # map global token indices onto the (deduplicated) block indices that contain them
    global_blocks = uniq(map(lambda t: t // block_size, sparse_attn_global_indices))

    self.attn_fn = SparseSelfAttention(
        sparsity_config=VariableSparsityConfig(
            num_heads=self.heads,
            block=self.block_size,
            num_random_blocks=num_random_blocks,
            global_block_indices=global_blocks,
            attention='unidirectional' if self.causal else 'bidirectional'),
        max_seq_length=self.seq_len,
        attn_mask_mode='add')
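# Hedged illustration (not part of the snippet above): how the
# `uniq(map(lambda t: t // block_size, sparse_attn_global_indices))` line maps
# token positions onto sparsity-layout block indices. The block size and the
# example indices below are made up for the demonstration.
block_size = 16
sparse_attn_global_indices = [0, 1, 500, 510, 1023]
global_blocks = sorted({t // block_size for t in sparse_attn_global_indices})
print(global_blocks)  # [0, 31, 63] -- tokens 500 and 510 fall in the same block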
def __init__(
    self,
    *args,
    block_size = 16,
    num_random_blocks = None,
    **kwargs
):
    super().__init__(*args, **kwargs)
    assert not exists(self.tie_attn_dim), 'sparse attention is not compatible with tying of row attention'
    assert exists(self.seq_len), '`seq_len` must be defined if using sparse attention class'

    from deepspeed.ops.sparse_attention import SparseSelfAttention, VariableSparsityConfig

    self.block_size = block_size

    # default to one quarter of the sequence blocks being sampled as random blocks
    num_random_blocks = default(num_random_blocks, self.seq_len // block_size // 4)

    self.attn_fn = SparseSelfAttention(
        sparsity_config = VariableSparsityConfig(
            num_heads = self.heads,
            block = self.block_size,
            num_random_blocks = num_random_blocks,
            attention = 'bidirectional'
        ),
        max_seq_length = self.seq_len,
        attn_mask_mode = 'add'
    )
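# Hedged worked example of the num_random_blocks default above (values assumed,
# not taken from the snippet): a 1024-token sequence split into 16-token blocks
# has 1024 // 16 = 64 blocks, so 64 // 4 = 16 of them end up as random blocks in
# the variable sparsity layout.
seq_len, block_size = 1024, 16
num_random_blocks = seq_len // block_size // 4
print(num_random_blocks)  # 16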
def __init__(self, dim, heads, seq_len, causal=True, dim_head=64, dropout=0., sparse_attn=False):
    super().__init__()
    inner_dim = heads * dim_head
    self.causal = causal
    self.heads = heads
    self.scale = dim_head ** -0.5
    self.dropout = nn.Dropout(dropout)

    if sparse_attn:
        # block-sparse attention via DeepSpeed's kernels
        from deepspeed.ops.sparse_attention import SparseSelfAttention, VariableSparsityConfig
        sparsity_config = VariableSparsityConfig(
            num_heads=heads,
            attention=("unidirectional" if causal else "bidirectional"))
        self.attn_fn = SparseSelfAttention(
            sparsity_config=sparsity_config,
            max_seq_length=seq_len,
            attn_mask_mode='add')
    else:
        # fall back to standard dense attention
        self.attn_fn = partial(dense_attn, dropout_fn=self.dropout)

    self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
    self.to_out = nn.Linear(inner_dim, dim)
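# Minimal usage sketch, assuming the __init__ above belongs to an nn.Module
# subclass named `Attention` (the class name, forward() signature, and tensor
# sizes are assumptions) and that DeepSpeed's sparse attention kernels plus a
# CUDA device are available. DeepSpeed's block-sparse kernels require seq_len
# to be a multiple of the sparsity block size (16 by default).
#
# attn = Attention(dim=512, heads=8, seq_len=1024, causal=True, sparse_attn=True).cuda()
# x = torch.randn(2, 1024, 512).cuda()
# out = attn(x)  # assumes the class also defines a forward() not shown here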
def __init__(self, attention_mask_func, init_method, output_layer_init_method,
             layer_number, sparse=False, rpe=None):
    super(ParallelSelfAttention, self).__init__()
    args = get_args()

    self.fp16 = args.fp16
    self.attention_mask_func = attention_mask_func
    self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
    self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
    if self.apply_query_key_layer_scaling:
        self.attention_softmax_in_fp32 = True
    self.layer_number = max(1, layer_number)  # TODO: why do we start from 1 here?

    # Per attention head and per partition values.
    world_size = mpu.get_model_parallel_world_size()
    self.hidden_size_per_partition = mpu.divide(args.hidden_size, world_size)
    self.hidden_size_per_attention_head = mpu.divide(
        args.hidden_size, args.num_attention_heads)
    self.num_attention_heads_per_partition = mpu.divide(
        args.num_attention_heads, world_size)

    # Strided linear layer.
    self.query_key_value = mpu.ColumnParallelLinear(
        args.hidden_size,
        3 * args.hidden_size,
        gather_output=False,
        init_method=init_method)

    coeff = None
    self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
    if self.apply_query_key_layer_scaling:
        coeff = self.layer_number
        self.norm_factor *= coeff

    self.rpe = rpe
    self.sparse = sparse
    if self.sparse:
        sparsity_config = VariableSparsityConfig(
            num_heads=self.num_attention_heads_per_partition,
            attention="unidirectional")
        try:
            self.sparse_attn = SparseSelfAttention(
                sparsity_config=sparsity_config,
                max_seq_length=args.seq_length,
                attn_mask_mode='add',
                mpu=mpu)
        except TypeError:
            # older versions don't have the mpu arg
            self.sparse_attn = SparseSelfAttention(
                sparsity_config=sparsity_config,
                max_seq_length=args.seq_length,
                attn_mask_mode='add')
    else:
        self.scale_mask_softmax = FusedScaleMaskSoftmax(
            self.fp16,
            args.scaled_upper_triang_masked_softmax_fusion,
            args.scaled_masked_softmax_fusion,
            self.attention_mask_func,
            self.attention_softmax_in_fp32,
            coeff)

        # Dropout. Note that for a single iteration, this layer will generate
        # different outputs on different number of parallel partitions but
        # on average it should not be partition dependent.
        self.attention_dropout = torch.nn.Dropout(args.attention_dropout)

    # Output.
    self.dense = mpu.RowParallelLinear(
        args.hidden_size,
        args.hidden_size,
        input_is_parallel=True,
        init_method=output_layer_init_method,
        skip_bias_add=True)

    if deepspeed.checkpointing.is_configured():
        global get_cuda_rng_tracker, checkpoint
        get_cuda_rng_tracker = deepspeed.checkpointing.get_cuda_rng_tracker
        checkpoint = deepspeed.checkpointing.checkpoint
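# Hedged worked example of the norm factor computed above (numbers assumed):
# with hidden_size=1024 and num_attention_heads=16, each head has 1024 / 16 = 64
# dims, so norm_factor = sqrt(64) = 8. With apply_query_key_layer_scaling on,
# layer 12 additionally multiplies by coeff = layer_number = 12, giving 96;
# coeff is also handed to FusedScaleMaskSoftmax above so the extra scaling can
# be compensated for inside the fused softmax.
import math
hidden_size_per_attention_head = 1024 // 16
norm_factor = math.sqrt(hidden_size_per_attention_head)  # 8.0
norm_factor *= 12                                        # 96.0 at layer_number = 12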
def configure_sparse_attention(neox_args, attention_type, num_attention_heads, mpu):
    from deepspeed.ops.sparse_attention import (
        SparseSelfAttention,
        VariableSparsityConfig,
        FixedSparsityConfig,
        BigBirdSparsityConfig,
        BSLongformerSparsityConfig,
    )
    from deepspeed.ops.sparse_attention.sparsity_config import (
        LocalSlidingWindowSparsityConfig,
    )

    if attention_type == "sparse_fixed":
        # you can think of local window size as `block_size` * `num_local_blocks`.
        # so if you wanted to set a local window size of 256, set block size to 16
        # and `num_local_blocks` to 16
        sparsity_config = FixedSparsityConfig(
            num_heads=num_attention_heads,
            block=neox_args.sparsity_config.get("block", 16),
            different_layout_per_head=neox_args.sparsity_config.get(
                "different_layout_per_head", False
            ),
            num_local_blocks=neox_args.sparsity_config.get("num_local_blocks", 4),
            num_global_blocks=neox_args.sparsity_config.get("num_global_blocks", 1),
            num_different_global_patterns=neox_args.sparsity_config.get(
                "num_different_global_patterns", 1
            ),
            attention="unidirectional",
            horizontal_global_attention=False,
        )
    elif attention_type == "sparse_variable":
        sparsity_config = VariableSparsityConfig(
            num_heads=num_attention_heads,
            block=neox_args.sparsity_config.get("block", 16),
            different_layout_per_head=neox_args.sparsity_config.get(
                "different_layout_per_head", False
            ),
            num_random_blocks=neox_args.sparsity_config.get("num_random_blocks", 0),
            local_window_blocks=neox_args.sparsity_config.get(
                "local_window_blocks", [4]
            ),
            global_block_indices=neox_args.sparsity_config.get(
                "global_block_indices", [0]
            ),
            global_block_end_indices=neox_args.sparsity_config.get(
                "global_block_end_indices", None
            ),
            attention="unidirectional",
            horizontal_global_attention=False,
        )
    elif attention_type == "local":
        # can configure with `num_local_blocks` or `num_sliding_window_blocks`
        num_local_blocks = neox_args.sparsity_config.get(
            "num_local_blocks",
            neox_args.sparsity_config.get("num_sliding_window_blocks", 4),
        )
        sparsity_config = LocalSlidingWindowSparsityConfig(
            num_heads=num_attention_heads,
            block=neox_args.sparsity_config.get("block", 16),
            num_sliding_window_blocks=num_local_blocks,
            attention="unidirectional",
        )
    elif attention_type == "bigbird":
        sparsity_config = BigBirdSparsityConfig(
            num_heads=num_attention_heads,
            block=neox_args.sparsity_config.get("block", 16),
            different_layout_per_head=neox_args.sparsity_config.get(
                "different_layout_per_head", False
            ),
            num_random_blocks=neox_args.sparsity_config.get("num_random_blocks", 1),
            num_sliding_window_blocks=neox_args.sparsity_config.get(
                "num_sliding_window_blocks", 3
            ),
            num_global_blocks=neox_args.sparsity_config.get("num_global_blocks", 1),
            attention="unidirectional",
        )
    elif attention_type == "bslongformer":
        sparsity_config = BSLongformerSparsityConfig(
            num_heads=num_attention_heads,
            block=neox_args.sparsity_config.get("block", 16),
            different_layout_per_head=neox_args.sparsity_config.get(
                "different_layout_per_head", False
            ),
            num_sliding_window_blocks=neox_args.sparsity_config.get(
                "num_sliding_window_blocks", 3
            ),
            global_block_indices=neox_args.sparsity_config.get(
                "global_block_indices", [0]
            ),
            global_block_end_indices=neox_args.sparsity_config.get(
                "global_block_end_indices", None
            ),
            attention="unidirectional",
        )
    else:
        raise ValueError(f"Attention type {attention_type} not recognized")
    return SparseSelfAttention(
        sparsity_config=sparsity_config,
        max_seq_length=neox_args.seq_length,
        attn_mask_mode="add",
        mpu=mpu,
    )
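# Hedged usage sketch: how the SparseSelfAttention module returned above might
# be driven. The q/k/v layout follows DeepSpeed's (batch, heads, seq, head_dim)
# convention; the tensor sizes, dtype/device handling, and the `neox_args`
# object are assumptions for illustration, not code taken from gpt-neox. The
# kernels expect half precision on a CUDA device and a sequence length that is
# a multiple of the sparsity block size.
#
# sparse_attn = configure_sparse_attention(neox_args, "bigbird", num_attention_heads=16, mpu=mpu)
# q = k = v = torch.randn(1, 16, neox_args.seq_length, 64, device="cuda", dtype=torch.half)
# context = sparse_attn(q, k, v)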