def __init__(self, dim, depth, max_seq_len, heads=8, dim_head=None, bucket_size=64, n_hashes=8, ff_chunks=100, attn_chunks=None, causal=False, weight_tie=False, lsh_dropout=0., ff_dropout=0., ff_activation=None, ff_mult=4, ff_glu=False, post_attn_dropout=0., layer_dropout=0., lsh_attend_across_buckets=True, lsh_allow_duplicate_attention=True, random_rotations_per_head=False, use_scale_norm=False, use_rezero=False, use_full_attn=False, full_attn_thres=0, k_means_hashing=False, reverse_thres=0, num_mem_kv=0, one_value_head=False, n_local_attn_heads=0, pkm_layers=tuple(), pkm_num_keys=128):
    super().__init__()
    self.dim = dim
    self.depth = depth

    self.bucket_size = bucket_size
    self.num_mem_kv = num_mem_kv
    self.full_attn_thres = full_attn_thres

    get_attn = lambda: LSHSelfAttention(dim=dim, max_seq_len=max_seq_len, heads=heads, bucket_size=bucket_size, n_hashes=n_hashes, causal=causal, dropout=lsh_dropout, attn_chunks=attn_chunks, allow_duplicate_attention=lsh_allow_duplicate_attention, attend_across_buckets=lsh_attend_across_buckets, random_rotations_per_head=random_rotations_per_head, num_mem_kv=num_mem_kv, use_full_attn=use_full_attn, full_attn_thres=full_attn_thres, k_means_hashing=k_means_hashing)
    get_ff = lambda: Chunk(ff_chunks, FeedForward(dim, dropout=ff_dropout, activation=ff_activation, mult=ff_mult, glu=ff_glu), along_dim=-2)
    get_pkm = lambda: PKM(dim, num_keys=pkm_num_keys)

    if weight_tie:
        get_attn, get_ff, get_pkm = map(cache_fn, (get_attn, get_ff, get_pkm))

    blocks = []

    norm_type = ScaleNorm if use_scale_norm else nn.LayerNorm

    residual_fn_wrapper = ReZero if use_rezero else partial(PreNorm, norm_type, dim)

    for ind in range(depth):
        layer_num = ind + 1
        use_pkm = layer_num in cast_tuple(pkm_layers)
        parallel_net = None

        attn = get_attn()

        if use_pkm:
            parallel_net = get_pkm()
        else:
            parallel_net = get_ff()

        f = residual_fn_wrapper(attn)
        g = residual_fn_wrapper(parallel_net)

        blocks.append(nn.ModuleList([f, g]))

    self.layers = ReversibleSequence(nn.ModuleList(blocks), layer_dropout=layer_dropout, reverse_thres=reverse_thres, send_signal=True)
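# --- Usage sketch (illustrative, not part of the source) --------------------
# The constructor above appears to belong to a Reformer-style module built on
# LSH self-attention with reversible layers (as in reformer-pytorch, here with
# an added `k_means_hashing` option). The class name `Reformer` and the import
# path below are assumptions; only keyword arguments present in the signature
# above are used.

import torch
from reformer_pytorch import Reformer  # assumed import path

model = Reformer(
    dim = 512,
    depth = 6,
    max_seq_len = 8192,
    heads = 8,
    bucket_size = 64,    # LSH bucket size
    n_hashes = 8,        # number of hash rounds; more rounds, better approximation
    lsh_dropout = 0.1,
    causal = True
)

x = torch.randn(1, 8192, 512)
y = model(x)             # (1, 8192, 512)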
def __init__(self, dim, depth, max_seq_len = None, causal = False, heads = 8, bucket_size = 64, kv_bucket_size = None, context_bucket_size = None, non_permutative = False, sinkhorn_iter = 5, n_sortcut = 0, temperature = 0.75, reversible = False, ff_chunks = 1, ff_dropout = 0., attn_dropout = 0., attn_layer_dropout = 0., layer_dropout = 0., weight_tie = False, ff_glu = False, use_simple_sort_net = None, receives_context = False, context_n_sortcut = 2, n_local_attn_heads = 0, use_rezero = False, n_top_buckets = 1, pkm_layers = tuple(), pkm_num_keys = 128):
    super().__init__()
    layers = nn.ModuleList([])

    kv_bucket_size = default(kv_bucket_size, bucket_size)
    context_bucket_size = default(context_bucket_size, bucket_size)

    get_attn = lambda: SinkhornSelfAttention(dim, bucket_size, max_seq_len, causal = causal, heads = heads, kv_bucket_size = kv_bucket_size, non_permutative = non_permutative, sinkhorn_iter = sinkhorn_iter, n_sortcut = n_sortcut, temperature = temperature, attn_dropout = attn_dropout, dropout = attn_layer_dropout, use_simple_sort_net = use_simple_sort_net, n_local_attn_heads = n_local_attn_heads, n_top_buckets = n_top_buckets)
    get_ff = lambda: Chunk(ff_chunks, FeedForward(dim, dropout = ff_dropout, glu = ff_glu), along_dim = 1)
    get_pkm = lambda: PKM(dim, num_keys = pkm_num_keys)

    get_attn_context = lambda: SinkhornSelfAttention(dim, bucket_size, max_seq_len, context_only = True, heads = heads, kv_bucket_size = context_bucket_size, non_permutative = non_permutative, sinkhorn_iter = sinkhorn_iter, n_sortcut = context_n_sortcut, temperature = temperature, attn_dropout = attn_dropout, dropout = attn_layer_dropout, n_top_buckets = n_top_buckets)
    get_ff_context = lambda: FeedForward(dim, dropout = ff_dropout, glu = ff_glu)

    if weight_tie:
        get_attn, get_attn_context, get_ff, get_ff_context = map(cache_fn, (get_attn, get_attn_context, get_ff, get_ff_context))

    fn_wrapper = partial(PreNorm, nn.LayerNorm, dim) if not use_rezero else ReZero

    for ind in range(depth):
        layer_num = ind + 1
        use_pkm = layer_num in pkm_layers

        get_parallel_fn = get_ff if not use_pkm else get_pkm

        layers.append(nn.ModuleList([
            fn_wrapper(get_attn()),
            fn_wrapper(get_parallel_fn())
        ]))

        if not receives_context:
            continue

        layers.append(nn.ModuleList([
            fn_wrapper(get_attn_context()),
            fn_wrapper(get_ff_context())
        ]))

    execute_type = ReversibleSequence if reversible else SequentialSequence

    attn_context_layer = ((True, False),) if receives_context else tuple()
    route_attn = ((True, False), *attn_context_layer) * depth
    route_context = ((False, False), *attn_context_layer) * depth

    context_route_map = {'context': route_context, 'context_mask': route_context} if receives_context else {}
    attn_route_map = {'input_mask': route_attn}

    self.layers = execute_type(layers, args_route = {**context_route_map, **attn_route_map}, layer_dropout = layer_dropout)

    self.receives_context = receives_context
    self.max_seq_len = max_seq_len
    self.pad_to_bucket_size = lcm(bucket_size, kv_bucket_size)
    self.context_bucket_size = context_bucket_size

    self.is_fixed_length = use_simple_sort_net and not causal  # if not using attention sort and also not causal, force fixed sequence length

    assert not (self.is_fixed_length and self.max_seq_len is None), 'maximum sequence length must be specified if length is fixed'
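# --- Usage sketch (illustrative, not part of the source) --------------------
# This constructor matches a Sinkhorn-attention transformer (as in
# sinkhorn-transformer, here extended with PKM layers). The class name and
# import path are assumptions. The sequence length is chosen to be divisible
# by the bucket size, matching `self.pad_to_bucket_size` above.

import torch
from sinkhorn_transformer import SinkhornTransformer  # assumed import path

model = SinkhornTransformer(
    dim = 1024,
    depth = 6,
    heads = 8,
    bucket_size = 128,
    max_seq_len = 2048
)

x = torch.randn(1, 2048, 1024)
y = model(x)   # (1, 2048, 1024)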
def __init__(self, dim, depth, max_seq_len, heads = 8, dim_head = None, bucket_size = 64, causal = False, ff_chunks = 1, ff_glu = False, ff_dropout = 0., attn_layer_dropout = 0., attn_dropout = 0., reversible = False, blindspot_size = 1, n_local_attn_heads = 0, local_attn_window_size = 128, receives_context = False, attend_axially = False, pkm_layers = tuple(), pkm_num_keys = 128, linformer_settings = None, context_linformer_settings = None):
    super().__init__()
    assert not (causal and exists(linformer_settings)), 'Linformer self attention layer can only be used for non-causal networks'
    assert not exists(linformer_settings) or isinstance(linformer_settings, LinformerSettings), 'Linformer self-attention settings must be a LinformerSettings namedtuple'
    assert not exists(context_linformer_settings) or isinstance(context_linformer_settings, LinformerContextSettings), 'Linformer contextual self-attention settings must be a LinformerContextSettings namedtuple'

    if type(n_local_attn_heads) is not tuple:
        n_local_attn_heads = tuple([n_local_attn_heads] * depth)

    assert len(n_local_attn_heads) == depth, 'local attention heads tuple must have the same length as the depth'
    assert all([(local_heads <= heads) for local_heads in n_local_attn_heads]), 'number of local attn heads must be less than the maximum number of heads'

    layers = nn.ModuleList([])

    for ind, local_heads in zip(range(depth), n_local_attn_heads):
        layer_num = ind + 1
        use_pkm = layer_num in cast_tuple(pkm_layers)

        parallel_net = Chunk(ff_chunks, FeedForward(dim), along_dim = 1) if not use_pkm else PKM(dim)

        if not exists(linformer_settings):
            attn = SelfAttention(dim, heads, causal, dim_head = dim_head, blindspot_size = blindspot_size, n_local_attn_heads = local_heads, local_attn_window_size = local_attn_window_size, dropout = attn_layer_dropout, attn_dropout = attn_dropout)
        else:
            attn = LinformerSelfAttention(dim, max_seq_len, heads = heads, dim_head = dim_head, dropout = attn_dropout, **linformer_settings._asdict())

        layers.append(nn.ModuleList([
            PreNorm(dim, attn),
            PreNorm(dim, parallel_net)
        ]))

        if attend_axially:
            layers.append(nn.ModuleList([
                PreNorm(dim, FoldAxially(local_attn_window_size, SelfAttention(dim, heads, causal, dropout = attn_layer_dropout, attn_dropout = attn_dropout))),
                PreNorm(dim, Chunk(ff_chunks, FeedForward(dim, glu = ff_glu, dropout = ff_dropout), along_dim = 1))
            ]))

        if receives_context:
            if not exists(context_linformer_settings):
                attn = SelfAttention(dim, heads, dim_head = dim_head, dropout = attn_layer_dropout, attn_dropout = attn_dropout, receives_context = True)
            else:
                attn = LinformerSelfAttention(dim, heads = heads, dim_head = dim_head, dropout = attn_dropout, **context_linformer_settings._asdict())

            layers.append(nn.ModuleList([
                PreNorm(dim, attn),
                PreNorm(dim, Chunk(ff_chunks, FeedForward(dim, glu = ff_glu, dropout = ff_dropout), along_dim = 1))
            ]))

    execute_type = ReversibleSequence if reversible else SequentialSequence

    axial_layer = ((True, False),) if attend_axially else tuple()
    attn_context_layer = ((True, False),) if receives_context else tuple()
    route_attn = ((True, False), *axial_layer, *attn_context_layer) * depth
    route_context = ((False, False), *axial_layer, *attn_context_layer) * depth

    context_route_map = {'context': route_context, 'context_mask': route_context} if receives_context else {}
    attn_route_map = {'input_mask': route_attn, 'pos_emb': route_attn}
    self.layers = execute_type(layers, args_route = {**attn_route_map, **context_route_map})

    self.pad_to_multiple = lcm(
        1 if not causal else blindspot_size,
        1 if all([(h == 0) for h in n_local_attn_heads]) else local_attn_window_size
    )
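# --- Usage sketch (illustrative, not part of the source) --------------------
# This constructor corresponds to a linear-attention transformer with optional
# local attention heads and optional Linformer settings (as in
# linear-attention-transformer). The class name and import path are
# assumptions. With any local attention heads enabled, inputs are padded to a
# multiple of `local_attn_window_size` (see `self.pad_to_multiple` above).

import torch
from linear_attention_transformer import LinearAttentionTransformer  # assumed import path

model = LinearAttentionTransformer(
    dim = 512,
    depth = 6,
    max_seq_len = 8192,
    heads = 8,
    n_local_attn_heads = 4,          # 4 of the 8 heads use windowed local attention
    local_attn_window_size = 128
)

x = torch.randn(1, 8192, 512)
y = model(x)   # (1, 8192, 512)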
def __init__(self, dim, depth, max_seq_len, heads=8, dim_head=None, window_size=64, local_attn_window_size=256, local_attn_radius_blocks=1, causal=False, weight_tie=False, attn_dropout=0., ff_dropout=0., attn_layer_dropout=0., layer_dropout=0., n_local_attn_heads=0, ff_glu=False, reversible=False, ff_chunks=1, kmeans_ema_decay=0.999, commitment_factor=1e-4, receives_context=False, context_window_size=None, _register_kmeans_update=False, rel_pos_emb=True, pkm_layers=tuple(), pkm_num_keys=128, moe_layers=tuple(), moe_num_experts=4, moe_loss_coef=1e-2, num_mem_kv=0, shared_qk=None, context_shared_qk=False, use_rezero=False, use_scale_norm=False, ff_activation=None):
    super().__init__()
    shared_qk = default(shared_qk, causal)  # default to shared qk when causal, due to experimental results

    if type(n_local_attn_heads) is not tuple:
        n_local_attn_heads = tuple([n_local_attn_heads] * depth)

    assert len(n_local_attn_heads) == depth, 'local attention heads tuple must have the same length as the depth'
    assert all([(local_heads <= heads) for local_heads in n_local_attn_heads]), 'number of local attn heads must be less than the maximum number of heads'

    layers = nn.ModuleList([])

    norm_type = ScaleNorm if use_scale_norm else nn.LayerNorm
    fn_wrapper = partial(ReZero) if use_rezero else partial(PreNorm, norm_type, dim)

    get_attn = lambda local_heads: SelfAttention(dim, depth, max_seq_len, heads, local_heads, window_size, causal=causal, dim_head=dim_head, local_attn_window_size=local_attn_window_size, local_attn_radius_blocks=local_attn_radius_blocks, attn_dropout=attn_dropout, dropout=attn_layer_dropout, kmeans_ema_decay=kmeans_ema_decay, commitment_factor=commitment_factor, rel_pos_emb=rel_pos_emb, num_mem_kv=num_mem_kv, shared_qk=shared_qk)
    get_ff = lambda: Chunk(ff_chunks, FeedForward(dim, dropout=ff_dropout, glu=ff_glu, activation=ff_activation), along_dim=1)
    get_context_attn = lambda: SelfAttention(dim, depth, max_seq_len, heads, 0, window_size, dim_head=dim_head, local_attn_window_size=local_attn_window_size, local_attn_radius_blocks=local_attn_radius_blocks, attn_dropout=attn_dropout, dropout=attn_layer_dropout, kmeans_ema_decay=kmeans_ema_decay, commitment_factor=commitment_factor, receives_context=True, context_window_size=context_window_size, num_mem_kv=num_mem_kv, shared_qk=context_shared_qk)
    get_context_ff = lambda: Chunk(ff_chunks, FeedForward(dim, dropout=ff_dropout, glu=ff_glu, activation=ff_activation), along_dim=1)
    get_pkm = lambda: PKM(dim, num_keys=pkm_num_keys)
    get_moe = lambda: MoE(dim, num_experts=moe_num_experts, loss_coef=moe_loss_coef)

    if weight_tie:
        assert len(set(n_local_attn_heads)) == 1, 'you can only weight tie if number of local attention heads for all layers is the same'
        get_attn, get_ff, get_context_attn, get_context_ff, get_pkm, get_moe = map(cache_fn, (get_attn, get_ff, get_context_attn, get_context_ff, get_pkm, get_moe))

    for ind, local_heads in zip(range(depth), n_local_attn_heads):
        layer = ind + 1
        use_pkm = layer in cast_tuple(pkm_layers)
        use_moe = layer in cast_tuple(moe_layers)

        get_parallel_fn = get_pkm if use_pkm else get_ff
        get_parallel_fn = get_moe if use_moe else get_parallel_fn

        layers.append(nn.ModuleList([
            fn_wrapper(get_attn(local_heads)),
            fn_wrapper(get_parallel_fn())
        ]))

        if not receives_context:
            continue

        layers.append(nn.ModuleList([
            fn_wrapper(get_context_attn()),
            fn_wrapper(get_context_ff())
        ]))

    execute_type = ReversibleSequence if reversible else SequentialSequence

    attn_context_layer = ((True, False),) if receives_context else tuple()
    route_attn = ((True, False), *attn_context_layer) * depth
    route_context = ((False, False), *attn_context_layer) * depth

    context_route_map = {'context': route_context, 'context_mask': route_context} if receives_context else {}
    attn_route_map = {'input_mask': route_attn}

    self.layers = execute_type(layers, args_route={**attn_route_map, **context_route_map}, layer_dropout=layer_dropout)

    self._handle = None
    if _register_kmeans_update:
        self.register_kmeans_update()

    has_local_attn = any([num > 0 for num in n_local_attn_heads])
    local_attn_window_size = default(local_attn_window_size, window_size)
    self.pad_to_multiple = local_attn_window_size if has_local_attn else 0
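# --- Usage sketch (illustrative, not part of the source) --------------------
# This constructor corresponds to a routing transformer mixing k-means routed
# attention with optional local attention heads, PKM and MoE layers (as in
# routing-transformer). The class name and import path are assumptions. Note
# that the k-means centroids need periodic updates during training; the
# constructor above exposes `register_kmeans_update()` for that purpose.

import torch
from routing_transformer import RoutingTransformer  # assumed import path

model = RoutingTransformer(
    dim = 512,
    depth = 6,
    max_seq_len = 8192,
    heads = 8,
    window_size = 128,        # size of each routed attention window
    n_local_attn_heads = 4,
    causal = True
)

x = torch.randn(1, 8192, 512)
out = model(x)   # (1, 8192, 512); some versions instead return (output, aux k-means commitment loss)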
# standalone variant of the per-layer PKM factory used in the constructors above;
# assumes `dim` and `pkm_num_keys` are available in the enclosing scope
def get_pkm():
    return PKM(dim, num_keys=pkm_num_keys)
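# --- Usage sketch (illustrative, not part of the source) --------------------
# PKM is assumed here to be a product-key memory layer (Lample et al., 2019)
# used in these models as a drop-in replacement for a feedforward sub-layer:
# it maps (batch, seq, dim) -> (batch, seq, dim). The import path below is an
# assumption.

import torch
from product_key_memory import PKM  # assumed import path

pkm = PKM(512, num_keys = 128)      # same call as the factories above, with dim = 512
tokens = torch.randn(1, 1024, 512)
out = pkm(tokens)                   # (1, 1024, 512)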