Example #1
    def __init__(self, dim, depth, max_seq_len, heads=8, dim_head=None, bucket_size=64, n_hashes=8, ff_chunks=100,
                 attn_chunks=None,
                 causal=False, weight_tie=False, lsh_dropout=0., ff_dropout=0., ff_activation=None, ff_mult=4,
                 ff_glu=False, post_attn_dropout=0., layer_dropout=0., lsh_attend_across_buckets=True,
                 lsh_allow_duplicate_attention=True, random_rotations_per_head=False, use_scale_norm=False,
                 use_rezero=False, use_full_attn=False, full_attn_thres=0, k_means_hashing=False, reverse_thres=0,
                 num_mem_kv=0,
                 one_value_head=False, n_local_attn_heads=0, pkm_layers=tuple(), pkm_num_keys=128):
        super().__init__()
        self.dim = dim
        self.depth = depth

        self.bucket_size = bucket_size
        self.num_mem_kv = num_mem_kv
        self.full_attn_thres = full_attn_thres

        get_attn = lambda: LSHSelfAttention(dim=dim, max_seq_len=max_seq_len, heads=heads, bucket_size=bucket_size,
                                            n_hashes=n_hashes, causal=causal,
                                            dropout=lsh_dropout,
                                            attn_chunks=attn_chunks,
                                            allow_duplicate_attention=lsh_allow_duplicate_attention,
                                            attend_across_buckets=lsh_attend_across_buckets,
                                            random_rotations_per_head=random_rotations_per_head, num_mem_kv=num_mem_kv,
                                            use_full_attn=use_full_attn, full_attn_thres=full_attn_thres,
                                            k_means_hashing=k_means_hashing)
        get_ff = lambda: Chunk(ff_chunks,
                               FeedForward(dim, dropout=ff_dropout, activation=ff_activation, mult=ff_mult, glu=ff_glu),
                               along_dim=-2)
        get_pkm = lambda: PKM(dim, num_keys=pkm_num_keys)

        if weight_tie:
            get_attn, get_ff, get_pkm = map(cache_fn, (get_attn, get_ff, get_pkm))

        blocks = []

        norm_type = ScaleNorm if use_scale_norm else nn.LayerNorm

        residual_fn_wrapper = ReZero if use_rezero else partial(PreNorm, norm_type, dim)

        for ind in range(depth):
            layer_num = ind + 1
            use_pkm = layer_num in cast_tuple(pkm_layers)
            parallel_net = None

            attn = get_attn()

            if use_pkm:
                parallel_net = get_pkm()
            else:
                parallel_net = get_ff()

            f = residual_fn_wrapper(attn)
            g = residual_fn_wrapper(parallel_net)

            blocks.append(nn.ModuleList([f, g]))

        self.layers = ReversibleSequence(nn.ModuleList(blocks), layer_dropout=layer_dropout,
                                         reverse_thres=reverse_thres, send_signal=True)
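
A minimal usage sketch for the constructor in Example #1. Only the __init__ body is shown, so the class name, import path, and (batch, seq_len, dim) input layout below are assumptions; the signature closely matches the Reformer encoder from lucidrains' reformer-pytorch, although this variant adds a k_means_hashing flag that the upstream package does not expose.

import torch
from reformer_pytorch import Reformer  # assumed class/package; the snippet omits the class name

model = Reformer(
    dim = 512,
    depth = 6,
    max_seq_len = 4096,
    heads = 8,
    bucket_size = 64,   # LSH bucket size
    n_hashes = 4,       # number of LSH hashing rounds
    causal = True
)

x = torch.randn(1, 4096, 512)  # (batch, seq_len, dim) token embeddings
y = model(x)                   # same shape out; blocks run through ReversibleSequence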
Example #2
    def __init__(self, dim, depth, max_seq_len = None, causal = False, heads = 8, bucket_size = 64,
                 kv_bucket_size = None, context_bucket_size = None, non_permutative = False, sinkhorn_iter = 5,
                 n_sortcut = 0, temperature = 0.75, reversible = False, ff_chunks = 1, ff_dropout = 0.,
                 attn_dropout = 0., attn_layer_dropout = 0., layer_dropout = 0., weight_tie = False, ff_glu = False,
                 use_simple_sort_net = None, receives_context = False, context_n_sortcut = 2, n_local_attn_heads = 0,
                 use_rezero = False, n_top_buckets = 1, pkm_layers = tuple(), pkm_num_keys = 128):
        super().__init__()
        layers = nn.ModuleList([])

        kv_bucket_size = default(kv_bucket_size, bucket_size)
        context_bucket_size = default(context_bucket_size, bucket_size)

        get_attn = lambda: SinkhornSelfAttention(dim, bucket_size, max_seq_len, causal = causal, heads = heads,
                                                 kv_bucket_size = kv_bucket_size, non_permutative = non_permutative,
                                                 sinkhorn_iter = sinkhorn_iter, n_sortcut = n_sortcut,
                                                 temperature = temperature, attn_dropout = attn_dropout,
                                                 dropout = attn_layer_dropout, use_simple_sort_net = use_simple_sort_net,
                                                 n_local_attn_heads = n_local_attn_heads, n_top_buckets = n_top_buckets)
        get_ff = lambda: Chunk(ff_chunks, FeedForward(dim, dropout = ff_dropout, glu = ff_glu), along_dim=1)
        get_pkm = lambda: PKM(dim, num_keys = pkm_num_keys)

        get_attn_context = lambda: SinkhornSelfAttention(dim, bucket_size, max_seq_len, context_only = True,
                                                         heads = heads, kv_bucket_size = context_bucket_size,
                                                         non_permutative = non_permutative, sinkhorn_iter = sinkhorn_iter,
                                                         n_sortcut = context_n_sortcut, temperature = temperature,
                                                         attn_dropout = attn_dropout, dropout = attn_layer_dropout,
                                                         n_top_buckets = n_top_buckets)
        get_ff_context = lambda: FeedForward(dim, dropout = ff_dropout, glu = ff_glu)

        if weight_tie:
            get_attn, get_attn_context, get_ff, get_ff_context = map(cache_fn, (get_attn, get_attn_context, get_ff, get_ff_context))

        fn_wrapper = partial(PreNorm, nn.LayerNorm, dim) if not use_rezero else ReZero

        for ind in range(depth):
            layer_num = ind + 1
            use_pkm = layer_num in cast_tuple(pkm_layers)

            get_parallel_fn = get_ff if not use_pkm else get_pkm

            layers.append(nn.ModuleList([
                fn_wrapper(get_attn()),
                fn_wrapper(get_parallel_fn())
            ]))

            if not receives_context:
                continue

            layers.append(nn.ModuleList([
                fn_wrapper(get_attn_context()),
                fn_wrapper(get_ff_context())
            ]))

        execute_type = ReversibleSequence if reversible else SequentialSequence

        attn_context_layer = ((True, False),) if receives_context else tuple()
        route_attn = ((True, False), *attn_context_layer) * depth
        route_context = ((False, False), *attn_context_layer) * depth

        context_route_map = {'context': route_context, 'context_mask': route_context} if receives_context else {}
        attn_route_map = {'input_mask': route_attn}

        self.layers = execute_type(layers, args_route = {**context_route_map, **attn_route_map}, layer_dropout = layer_dropout)
        self.receives_context = receives_context

        self.max_seq_len = max_seq_len
        self.pad_to_bucket_size = lcm(bucket_size, kv_bucket_size)
        self.context_bucket_size = context_bucket_size

        self.is_fixed_length = use_simple_sort_net and not causal

        # if not using attention sort and also not causal, force fixed sequence length
        assert not (self.is_fixed_length and self.max_seq_len is None), 'maximum sequence length must be specified if length is fixed'
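
A usage sketch for the constructor in Example #2, assuming it belongs to the SinkhornTransformer class from lucidrains' sinkhorn-transformer; the class name, import path, and (batch, seq_len, dim) input layout are not shown in the snippet and are assumptions here. The sequence length is chosen as a multiple of the bucket size, in line with the pad_to_bucket_size bookkeeping above.

import torch
from sinkhorn_transformer import SinkhornTransformer  # assumed class/package; the snippet omits the class name

model = SinkhornTransformer(
    dim = 512,
    depth = 6,
    max_seq_len = 8192,
    heads = 8,
    bucket_size = 128,   # sequence length should divide evenly into buckets
    causal = False
)

x = torch.randn(1, 8192, 512)  # (batch, seq_len, dim)
y = model(x)                   # (1, 8192, 512)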
Example #3
    def __init__(self, dim, depth, max_seq_len, heads = 8, dim_head = None, bucket_size = 64, causal = False,
                 ff_chunks = 1, ff_glu = False, ff_dropout = 0., attn_layer_dropout = 0., attn_dropout = 0.,
                 reversible = False, blindspot_size = 1, n_local_attn_heads = 0, local_attn_window_size = 128,
                 receives_context = False, attend_axially = False, pkm_layers = tuple(), pkm_num_keys = 128,
                 linformer_settings = None, context_linformer_settings = None):
        super().__init__()
        assert not (causal and exists(linformer_settings)), 'Linformer self attention layer can only be used for non-causal networks'
        assert not exists(linformer_settings) or isinstance(linformer_settings, LinformerSettings), 'Linformer self-attention settings must be a LinformerSettings namedtuple'
        assert not exists(context_linformer_settings) or isinstance(context_linformer_settings, LinformerContextSettings), 'Linformer contextual self-attention settings must be a LinformerContextSettings namedtuple'

        if type(n_local_attn_heads) is not tuple:
            n_local_attn_heads = tuple([n_local_attn_heads] * depth)

        assert len(n_local_attn_heads) == depth, 'local attention heads tuple must have the same length as the depth'
        assert all([(local_heads <= heads) for local_heads in n_local_attn_heads]), 'number of local attn heads must be less than or equal to the total number of heads'

        layers = nn.ModuleList([])

        for ind, local_heads in zip(range(depth), n_local_attn_heads):
            layer_num = ind + 1
            use_pkm = layer_num in cast_tuple(pkm_layers)

            parallel_net = Chunk(ff_chunks, FeedForward(dim), along_dim = 1) if not use_pkm else PKM(dim, num_keys = pkm_num_keys)

            if not exists(linformer_settings):
                attn = SelfAttention(dim, heads, causal, dim_head = dim_head, blindspot_size = blindspot_size, n_local_attn_heads = local_heads, local_attn_window_size = local_attn_window_size, dropout = attn_layer_dropout, attn_dropout= attn_dropout)
            else:
                attn = LinformerSelfAttention(dim, max_seq_len, heads = heads, dim_head = dim_head, dropout = attn_dropout, **linformer_settings._asdict())

            layers.append(nn.ModuleList([
                PreNorm(dim, attn),
                PreNorm(dim, parallel_net)
            ]))

            if attend_axially:
                layers.append(nn.ModuleList([
                    PreNorm(dim, FoldAxially(local_attn_window_size, SelfAttention(dim, heads, causal, dropout = attn_layer_dropout, attn_dropout= attn_dropout))),
                    PreNorm(dim, Chunk(ff_chunks, FeedForward(dim, glu = ff_glu, dropout= ff_dropout), along_dim = 1))
                ]))

            if receives_context:
                if not exists(context_linformer_settings):
                    attn = SelfAttention(dim, heads, dim_head = dim_head, dropout = attn_layer_dropout, attn_dropout= attn_dropout, receives_context = True)
                else:
                    attn = LinformerSelfAttention(dim, heads = heads, dim_head = dim_head, dropout = attn_dropout, **context_linformer_settings._asdict())

                layers.append(nn.ModuleList([
                    PreNorm(dim, attn),
                    PreNorm(dim, Chunk(ff_chunks, FeedForward(dim, glu = ff_glu, dropout= ff_dropout), along_dim = 1))
                ]))

        execute_type = ReversibleSequence if reversible else SequentialSequence

        axial_layer = ((True, False),) if attend_axially else tuple()
        attn_context_layer = ((True, False),) if receives_context else tuple()
        route_attn = ((True, False), *axial_layer, *attn_context_layer) * depth
        route_context = ((False, False), *axial_layer, *attn_context_layer) * depth

        context_route_map = {'context': route_context, 'context_mask': route_context} if receives_context else {}
        attn_route_map = {'input_mask': route_attn, 'pos_emb': route_attn}
        self.layers = execute_type(layers, args_route = {**attn_route_map, **context_route_map})

        self.pad_to_multiple = lcm(
            1 if not causal else blindspot_size,
            1 if all([(h == 0) for h in n_local_attn_heads]) else local_attn_window_size
        )
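
A usage sketch for the constructor in Example #3, assuming it is the LinearAttentionTransformer class from lucidrains' linear-attention-transformer; the class name, import path, and (batch, seq_len, dim) input layout are assumptions, since only the __init__ body is shown. The sequence length is a multiple of local_attn_window_size, so the pad_to_multiple constraint computed above is satisfied.

import torch
from linear_attention_transformer import LinearAttentionTransformer  # assumed class/package

model = LinearAttentionTransformer(
    dim = 512,
    depth = 6,
    max_seq_len = 8192,
    heads = 8,
    causal = False,
    n_local_attn_heads = 4,        # 4 of the 8 heads use windowed local attention
    local_attn_window_size = 128   # pad_to_multiple becomes 128 when local heads are present
)

x = torch.randn(1, 8192, 512)  # (batch, seq_len, dim)
y = model(x)                   # (1, 8192, 512)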
Example #4
    def __init__(self,
                 dim,
                 depth,
                 max_seq_len,
                 heads=8,
                 dim_head=None,
                 window_size=64,
                 local_attn_window_size=256,
                 local_attn_radius_blocks=1,
                 causal=False,
                 weight_tie=False,
                 attn_dropout=0.,
                 ff_dropout=0.,
                 attn_layer_dropout=0.,
                 layer_dropout=0.,
                 n_local_attn_heads=0,
                 ff_glu=False,
                 reversible=False,
                 ff_chunks=1,
                 kmeans_ema_decay=0.999,
                 commitment_factor=1e-4,
                 receives_context=False,
                 context_window_size=None,
                 _register_kmeans_update=False,
                 rel_pos_emb=True,
                 pkm_layers=tuple(),
                 pkm_num_keys=128,
                 moe_layers=tuple(),
                 moe_num_experts=4,
                 moe_loss_coef=1e-2,
                 num_mem_kv=0,
                 shared_qk=None,
                 context_shared_qk=False,
                 use_rezero=False,
                 use_scale_norm=False,
                 ff_activation=None):
        super().__init__()
        shared_qk = default(
            shared_qk, causal
        )  # default to shared qk when causal, due to experimental results

        if type(n_local_attn_heads) is not tuple:
            n_local_attn_heads = tuple([n_local_attn_heads] * depth)

        assert len(
            n_local_attn_heads
        ) == depth, 'local attention heads tuple must have the same length as the depth'
        assert all(
            [(local_heads <= heads) for local_heads in n_local_attn_heads]
        ), 'number of local attn heads must be less than or equal to the total number of heads'

        layers = nn.ModuleList([])

        norm_type = ScaleNorm if use_scale_norm else nn.LayerNorm

        fn_wrapper = partial(ReZero) if use_rezero else partial(
            PreNorm, norm_type, dim)

        get_attn = lambda local_heads: SelfAttention(
            dim,
            depth,
            max_seq_len,
            heads,
            local_heads,
            window_size,
            causal=causal,
            dim_head=dim_head,
            local_attn_window_size=local_attn_window_size,
            local_attn_radius_blocks=local_attn_radius_blocks,
            attn_dropout=attn_dropout,
            dropout=attn_layer_dropout,
            kmeans_ema_decay=kmeans_ema_decay,
            commitment_factor=commitment_factor,
            rel_pos_emb=rel_pos_emb,
            num_mem_kv=num_mem_kv,
            shared_qk=shared_qk)
        get_ff = lambda: Chunk(
            ff_chunks,
            FeedForward(
                dim, dropout=ff_dropout, glu=ff_glu, activation=ff_activation),
            along_dim=1)
        get_context_attn = lambda: SelfAttention(
            dim,
            depth,
            max_seq_len,
            heads,
            0,
            window_size,
            dim_head=dim_head,
            local_attn_window_size=local_attn_window_size,
            local_attn_radius_blocks=local_attn_radius_blocks,
            attn_dropout=attn_dropout,
            dropout=attn_layer_dropout,
            kmeans_ema_decay=kmeans_ema_decay,
            commitment_factor=commitment_factor,
            receives_context=True,
            context_window_size=context_window_size,
            num_mem_kv=num_mem_kv,
            shared_qk=context_shared_qk)
        get_context_ff = lambda: Chunk(
            ff_chunks,
            FeedForward(
                dim, dropout=ff_dropout, glu=ff_glu, activation=ff_activation),
            along_dim=1)
        get_pkm = lambda: PKM(dim, num_keys=pkm_num_keys)
        get_moe = lambda: MoE(
            dim, num_experts=moe_num_experts, loss_coef=moe_loss_coef)

        if weight_tie:
            assert len(
                set(n_local_attn_heads)
            ) == 1, 'you can only weight tie if number of local attention heads for all layers is the same'
            get_attn, get_ff, get_context_attn, get_context_ff, get_pkm, get_moe = map(
                cache_fn, (get_attn, get_ff, get_context_attn, get_context_ff,
                           get_pkm, get_moe))

        for ind, local_heads in zip(range(depth), n_local_attn_heads):
            layer = ind + 1
            use_pkm = layer in cast_tuple(pkm_layers)
            use_moe = layer in cast_tuple(moe_layers)

            get_parallel_fn = get_pkm if use_pkm else get_ff
            get_parallel_fn = get_moe if use_moe else get_parallel_fn

            layers.append(
                nn.ModuleList([
                    fn_wrapper(get_attn(local_heads)),
                    fn_wrapper(get_parallel_fn())
                ]))

            if not receives_context:
                continue

            layers.append(
                nn.ModuleList([
                    fn_wrapper(get_context_attn()),
                    fn_wrapper(get_context_ff())
                ]))

        execute_type = ReversibleSequence if reversible else SequentialSequence

        attn_context_layer = ((True, False), ) if receives_context else tuple()
        route_attn = ((True, False), *attn_context_layer) * depth
        route_context = ((False, False), *attn_context_layer) * depth

        context_route_map = {
            'context': route_context,
            'context_mask': route_context
        } if receives_context else {}
        attn_route_map = {'input_mask': route_attn}
        self.layers = execute_type(layers,
                                   args_route={
                                       **attn_route_map,
                                       **context_route_map
                                   },
                                   layer_dropout=layer_dropout)

        self._handle = None
        if _register_kmeans_update:
            self.register_kmeans_update()

        has_local_attn = any([num > 0 for num in n_local_attn_heads])
        local_attn_window_size = default(local_attn_window_size, window_size)
        self.pad_to_multiple = local_attn_window_size if has_local_attn else 0
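
A usage sketch for the constructor in Example #4, assuming it belongs to the RoutingTransformer class from lucidrains' routing-transformer; the class name, import path, input layout, and the auxiliary-loss return value are assumptions. The k-means routed attention carries a commitment loss that is typically added to the training objective.

import torch
from routing_transformer import RoutingTransformer  # assumed class/package; the snippet omits the class name

model = RoutingTransformer(
    dim = 512,
    depth = 6,
    max_seq_len = 8192,
    heads = 8,
    window_size = 128,   # k-means attention window; sequence length should be a multiple of this
    causal = True        # shared_qk then defaults to True, per the comment in the constructor above
)

x = torch.randn(1, 8192, 512)          # (batch, seq_len, dim)
y, aux_loss = model(x)                 # assumed to return features plus an auxiliary commitment loss
# total_loss = task_loss + aux_loss    # fold the routing loss into the objective during training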
Example #5
    def get_pkm():
        return PKM(dim, num_keys=pkm_num_keys)