def __init__(
    self,
    dim,
    causal = False,
    heads = 8,
    dim_head = 64,
    local_heads = 0,
    local_window_size = 256,
    nb_features = None,
    feature_redraw_interval = 1000,
    generalized_attention = False,
    kernel_fn = nn.ReLU(),
    dropout = 0.,
    no_projection = False,
    qkv_bias = False,
    attn_out_bias = True
):
    super().__init__()
    assert dim % heads == 0, 'dimension must be divisible by number of heads'
    dim_head = default(dim_head, dim // heads)
    inner_dim = dim_head * heads
    self.fast_attention = FastAttention(dim_head, nb_features, causal = causal, generalized_attention = generalized_attention, kernel_fn = kernel_fn, no_projection = no_projection)

    # heads are split between global (Performer) attention and windowed local attention
    self.heads = heads
    self.global_heads = heads - local_heads
    self.local_attn = LocalAttention(window_size = local_window_size, causal = causal, autopad = True, dropout = dropout, look_forward = int(not causal), rel_pos_emb_config = (dim_head, local_heads)) if local_heads > 0 else None

    # query / key / value projections and output projection
    self.to_q = nn.Linear(dim, inner_dim, bias = qkv_bias)
    self.to_k = nn.Linear(dim, inner_dim, bias = qkv_bias)
    self.to_v = nn.Linear(dim, inner_dim, bias = qkv_bias)
    self.to_out = nn.Linear(inner_dim, dim, bias = attn_out_bias)
    self.dropout = nn.Dropout(dropout)
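# Usage sketch -- assuming this constructor belongs to the SelfAttention block of
# lucidrains/performer-pytorch (the signature matches); the import path below is
# that assumption, not something stated in this file.
import torch
from performer_pytorch import SelfAttention  # assumed import path

attn = SelfAttention(dim = 512, heads = 8, dim_head = 64, causal = False, local_heads = 2)
x = torch.randn(1, 1024, 512)  # (batch, seq_len, dim)
out = attn(x)                  # -> (1, 1024, 512)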
def __init__(self, dim, heads = 8, bucket_size = 64, n_hashes = 8, causal = False, dim_head = None, attn_chunks = 1, random_rotations_per_head = False, attend_across_buckets = True, allow_duplicate_attention = True, num_mem_kv = 0, one_value_head = False, use_full_attn = False, full_attn_thres = None, return_attn = False, post_attn_dropout = 0., dropout = 0., n_local_attn_heads = 0, **kwargs):
    super().__init__()
    assert dim_head or (dim % heads) == 0, 'dimensions must be divisible by number of heads'
    assert n_local_attn_heads < heads, 'local attention heads must be less than number of heads'

    dim_head = default(dim_head, dim // heads)
    dim_heads = dim_head * heads

    self.dim = dim
    self.heads = heads
    self.dim_head = dim_head
    self.attn_chunks = default(attn_chunks, 1)

    # with one_value_head, a single value head is shared (repeated) across all heads
    self.v_head_repeats = (heads if one_value_head else 1)
    v_dim = dim_heads // self.v_head_repeats

    # queries and keys share one projection (LSH attention is shared-QK)
    self.toqk = nn.Linear(dim, dim_heads, bias = False)
    self.tov = nn.Linear(dim, v_dim, bias = False)
    self.to_out = nn.Linear(dim_heads, dim)

    self.bucket_size = bucket_size
    self.lsh_attn = LSHAttention(bucket_size = bucket_size, n_hashes = n_hashes, causal = causal, random_rotations_per_head = random_rotations_per_head, attend_across_buckets = attend_across_buckets, allow_duplicate_attention = allow_duplicate_attention, return_attn = return_attn, dropout = dropout, **kwargs)
    self.full_attn = FullQKAttention(causal = causal, dropout = dropout)
    self.post_attn_dropout = nn.Dropout(post_attn_dropout)

    # sequences at or below full_attn_thres fall back to full attention
    self.use_full_attn = use_full_attn
    self.full_attn_thres = default(full_attn_thres, bucket_size)

    # optional learned memory key / values
    self.num_mem_kv = num_mem_kv
    self.mem_kv = nn.Parameter(torch.randn(1, num_mem_kv, dim, requires_grad = True)) if num_mem_kv > 0 else None

    self.n_local_attn_heads = n_local_attn_heads
    self.local_attn = LocalAttention(window_size = bucket_size * 2, causal = causal, dropout = dropout, shared_qk = True, look_forward = (1 if not causal else 0))

    self.callback = None
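# Usage sketch -- the signature matches LSHSelfAttention from
# lucidrains/reformer-pytorch, so assuming that package:
import torch
from reformer_pytorch import LSHSelfAttention

attn = LSHSelfAttention(dim = 128, heads = 8, bucket_size = 64, n_hashes = 8, causal = False)
x = torch.randn(10, 1024, 128)  # seq_len must be a multiple of bucket_size * 2
y = attn(x)                     # -> (10, 1024, 128)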
def __init__(self, dim, heads, causal = False, dim_head = None, blindspot_size = 1, n_local_attn_heads = 0, local_attn_window_size = 128, receives_context = False, dropout = 0., attn_dropout = 0.):
    super().__init__()
    assert dim_head or (dim % heads) == 0, 'embedding dimension must be divisible by number of heads'
    d_heads = default(dim_head, dim // heads)

    self.heads = heads
    self.d_heads = d_heads
    self.receives_context = receives_context

    # non-causal heads use plain linear attention; causal heads use bucketed
    # causal linear attention with a blindspot of `blindspot_size`
    self.global_attn_heads = heads - n_local_attn_heads
    self.global_attn_fn = linear_attn if not causal else partial(causal_linear_attn, bucket_size = blindspot_size)

    self.local_attn_heads = n_local_attn_heads
    self.local_attn = LocalAttention(local_attn_window_size, causal = causal, dropout = attn_dropout)

    self.to_q = nn.Linear(dim, d_heads * heads, bias = False)

    kv_heads = heads

    self.kv_heads = kv_heads
    self.to_k = nn.Linear(dim, d_heads * kv_heads, bias = False)
    self.to_v = nn.Linear(dim, d_heads * kv_heads, bias = False)

    self.to_out = nn.Linear(d_heads * heads, dim)
    self.dropout = nn.Dropout(dropout)
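# Usage sketch -- this appears to be the SelfAttention module of
# lucidrains/linear-attention-transformer; the deep import path is an assumption
# about where the class lives in that repo.
import torch
from linear_attention_transformer.linear_attention_transformer import SelfAttention  # assumed path

attn = SelfAttention(dim = 512, heads = 8, causal = True, blindspot_size = 1)
x = torch.randn(1, 1024, 512)
out = attn(x)  # -> (1, 1024, 512)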
def __init__(self, dim, depth, max_seq_len, heads, local_attn_heads, window_size, dim_head = None, local_attn_window_size = None, local_attn_radius_blocks = 1, causal = False, attn_dropout = 0., dropout = 0., kmeans_ema_decay = 0.999, commitment_factor = 1e-4, receives_context = False, context_window_size = None, rel_pos_emb = True, num_mem_kv = 0, shared_qk = False, conv_query_kernel = 9):
    super().__init__()
    assert dim_head or (dim % heads) == 0, 'hidden dimension must be divisible by number of heads'
    assert (max_seq_len % window_size) == 0, 'maximum sequence length must be divisible by the target window size'
    assert local_attn_heads <= heads, 'number of local attention heads cannot exceed total heads'
    assert not (receives_context and local_attn_heads > 0), 'local attention cannot be used for self attention with context'
    assert not (receives_context and causal), 'contextual attention layer cannot be causal'

    local_attn_window_size = default(local_attn_window_size, window_size)
    context_window_size = default(context_window_size, window_size)

    self.shared_qk = shared_qk
    self.receives_context = receives_context
    self.heads = heads
    self.local_attn_heads = local_attn_heads
    self.global_attn_heads = heads - local_attn_heads

    self.causal = causal
    self.window_size = window_size

    dim_head = default(dim_head, dim // heads)
    dim_heads = dim_head * heads
    self.dim_head = dim_head

    num_clusters = max_seq_len // window_size

    # local attention heads

    local_dim_heads = dim_head * self.local_attn_heads

    if self.local_attn_heads > 0:
        rel_pos_emb_config = (dim_head, local_attn_heads) if rel_pos_emb else None
        self.local_attn = LocalAttention(dim = dim_head, window_size = local_attn_window_size, causal = causal, dropout = attn_dropout, rel_pos_emb_config = rel_pos_emb_config, look_backward = local_attn_radius_blocks, look_forward = 0 if causal else local_attn_radius_blocks)
        self.local_to_qkv = nn.Linear(dim, 3 * local_dim_heads)

    # global attention heads, routed by k-means clustering

    global_dim_heads = dim_head * self.global_attn_heads

    if self.global_attn_heads > 0:
        self.global_attn = KmeansAttention(num_clusters, window_size, self.global_attn_heads, dim_head, causal = causal, dropout = attn_dropout, ema_decay = kmeans_ema_decay, commitment = commitment_factor, receives_context = receives_context, num_mem_kv = num_mem_kv, shared_qk = shared_qk)

    self.to_q = nn.Linear(dim, global_dim_heads, bias = False)
    self.to_v = nn.Linear(dim, global_dim_heads, bias = False)

    if not self.shared_qk:
        self.to_k = nn.Linear(dim, global_dim_heads, bias = False)

    # out

    self.to_out = nn.Linear(dim_heads, dim, bias = False)
    self.dropout = nn.Dropout(dropout)
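# Usage sketch -- the signature matches SelfAttention in
# lucidrains/routing-transformer; both the import path and the (output, aux loss)
# return convention below are assumptions based on that repo.
import torch
from routing_transformer.routing_transformer import SelfAttention  # assumed path

attn = SelfAttention(dim = 512, depth = 1, max_seq_len = 4096, heads = 8,
                     local_attn_heads = 4, window_size = 128, causal = True)
x = torch.randn(1, 4096, 512)
out, aux_loss = attn(x)  # aux_loss: k-means commitment loss (assumed return signature)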
def __init__(self, dim, causal=False, heads=8, local_heads=0, local_window_size=256, nb_features=None, redraw_projection=True, generalized_attention=False, kernel_fn=nn.ReLU(), qr_uniform_q=False, dropout=0.):
    super().__init__()
    assert dim % heads == 0, 'dimension must be divisible by number of heads'
    self.fast_attention = FastAttention(dim // heads, nb_features, redraw_projection, causal=causal, generalized_attention=generalized_attention, kernel_fn=kernel_fn, qr_uniform_q=qr_uniform_q)

    self.heads = heads
    self.global_heads = heads - local_heads
    self.local_attn = LocalAttention(window_size=local_window_size, causal=causal, autopad=True, dropout=dropout, look_forward=int(not causal)) if local_heads > 0 else None

    self.to_q = nn.Linear(dim, dim)
    self.to_k = nn.Linear(dim, dim)
    self.to_v = nn.Linear(dim, dim)
    self.to_out = nn.Linear(dim, dim)
    self.dropout = nn.Dropout(dropout)
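# Usage sketch -- this is an earlier Performer-style variant (redraw_projection /
# qr_uniform_q instead of feature_redraw_interval). The enclosing class is not
# shown here, so `PerformerSelfAttention` below is a placeholder name:
#
#   attn = PerformerSelfAttention(dim = 512, heads = 8, nb_features = 256,
#                                 redraw_projection = True, causal = False)
#   out = attn(torch.randn(1, 1024, 512))  # -> (1, 1024, 512)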
def __init__(self, dim, causal=False, heads=8, local_heads=0, local_window_size=256, nb_features=None, feature_redraw_interval=1000, generalized_attention=False, kernel_fn=nn.ReLU(), qr_uniform_q=False, dropout=0., amp_enabled=False):
    super().__init__()
    assert dim % heads == 0, 'dimension must be divisible by number of heads'
    dim_head = dim // heads
    self.fast_attention = FastAttention(
        dim_head, nb_features, feature_redraw_interval,
        causal=causal, generalized_attention=generalized_attention,
        kernel_fn=kernel_fn, qr_uniform_q=qr_uniform_q, amp_enabled=amp_enabled)

    self.heads = heads
    self.global_heads = heads - local_heads
    self.local_attn = LocalAttention(
        window_size=local_window_size, causal=causal, autopad=True, dropout=dropout,
        look_forward=int(not causal), rel_pos_emb_config=(dim_head, local_heads)) if local_heads > 0 else None

    self.to_q = nn.Linear(dim, dim)
    self.to_k = nn.Linear(dim, dim)
    self.to_v = nn.Linear(dim, dim)
    self.to_out = nn.Linear(dim, dim)
    self.dropout = nn.Dropout(dropout)
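# Usage sketch -- same Performer family, extended with feature_redraw_interval and
# an amp_enabled flag (presumably for mixed-precision-safe feature projections).
# The enclosing class is again not shown, so the name is a placeholder:
#
#   attn = PerformerSelfAttention(dim = 512, heads = 8, local_heads = 2,
#                                 feature_redraw_interval = 1000, amp_enabled = True)
#   out = attn(torch.randn(1, 1024, 512))  # -> (1, 1024, 512)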
def __init__(self, dim, bucket_size, max_seq_len, heads=8, dim_head=None, kv_bucket_size=None, causal=False, non_permutative=True, sinkhorn_iter=5, n_sortcut=0, temperature=0.75, attn_dropout=0., dropout=0., context_only=False, use_simple_sort_net=False, n_local_attn_heads=0, n_top_buckets=1):
    super().__init__()
    assert dim_head or divisible_by(dim, heads), f'If dim_head is None, dimension {dim} must be divisible by the number of heads {heads}'
    assert not (causal and n_sortcut > 0), 'sortcut can only be used for non causal attention'
    assert not (causal and context_only), 'context only self attention layer cannot be causal'
    assert n_local_attn_heads <= heads, 'number of local attention heads cannot exceed total heads'

    dim_head = default(dim_head, dim // heads)
    dim_heads = dim_head * heads
    self.dim_head = dim_head

    self.heads = heads
    self.bucket_size = bucket_size
    self.kv_bucket_size = default(kv_bucket_size, bucket_size)

    self.context_only = context_only
    self.to_q = nn.Linear(dim, dim_heads, bias=False)
    self.to_kv = nn.Linear(dim, dim_heads * 2, bias=False) if not context_only else None

    self.to_out = nn.Linear(dim_heads, dim)

    self.n_local_attn_heads = n_local_attn_heads
    self.local_attention = LocalAttention(bucket_size, causal, dropout=attn_dropout, look_forward=(1 if not causal else 0))

    # heads not assigned to local attention perform sinkhorn-sorted bucket attention
    sink_heads = heads - n_local_attn_heads

    if causal:
        attn = SinkhornCausalAttention(bucket_size, dim, dim_head, sink_heads, max_seq_len, dropout=attn_dropout, kv_bucket_size=kv_bucket_size, use_simple_sort_net=use_simple_sort_net, n_top_buckets=n_top_buckets, temperature=temperature)
    else:
        attn = SinkhornAttention(bucket_size, dim, dim_head, sink_heads, max_seq_len, non_permutative=non_permutative, sinkhorn_iter=sinkhorn_iter, n_sortcut=n_sortcut, temperature=temperature, dropout=attn_dropout, kv_bucket_size=kv_bucket_size, use_simple_sort_net=use_simple_sort_net, n_top_buckets=n_top_buckets)

    self.sinkhorn_attention = attn

    self.dropout = nn.Dropout(dropout)
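# Usage sketch -- the signature matches SinkhornSelfAttention in
# lucidrains/sinkhorn-transformer; the import path is an assumption.
import torch
from sinkhorn_transformer.sinkhorn_transformer import SinkhornSelfAttention  # assumed path

attn = SinkhornSelfAttention(dim = 512, bucket_size = 64, max_seq_len = 2048,
                             heads = 8, causal = False, sinkhorn_iter = 5)
x = torch.randn(1, 2048, 512)  # seq_len must be a multiple of bucket_size
out = attn(x)                  # -> (1, 2048, 512)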