def __init__(self, emb, depth, max_seq_len, num_tokens = 10000, heads = 8, bucket_size = 64, n_hashes = 8, ff_chunks = 100, attn_chunks = None, causal = False, weight_tie = False, lsh_dropout = 0., random_rotations_per_head = False, twin_attention = False, use_scale_norm = False, use_full_attn = False):
        super().__init__()
        self.emb = emb
        self.depth = depth
        self.token_emb = nn.Embedding(num_tokens, emb)
        self.pos_emb = nn.Embedding(max_seq_len, emb)

        get_full_attn = lambda: SelfAttention(emb, heads, causal = causal)
        get_lsh_attn = lambda: LSHSelfAttention(emb, heads, bucket_size, n_hashes, causal = causal, dropout = lsh_dropout, attn_chunks = attn_chunks, random_rotations_per_head = random_rotations_per_head)

        get_attn = get_full_attn if use_full_attn else get_lsh_attn
        get_ff = lambda: FeedForward(emb)

        # weight tying: memoize the constructors so every layer reuses the same modules
        if weight_tie:
            get_attn = cache_fn(get_attn)
            get_ff = cache_fn(get_ff)

        blocks = []
        norm_type = ScaleNorm if use_scale_norm else nn.LayerNorm

        for _ in range(depth):
            attn = get_attn()
            parallel_net = get_attn() if twin_attention else get_ff()

            f = WithNorm(norm_type, emb, attn)
            g = WithNorm(norm_type, emb, parallel_net)

            if not twin_attention and ff_chunks > 1:
                # evaluate the feed-forward in chunks along the sequence dimension to save memory
                g = Chunk(ff_chunks, g, along_dim = -2)

            # reversible residual block: activations are recomputed in the backward pass instead of stored
            blocks.append(ReversibleBlock(f, g, split_along_dim=-1))

        self.layers = ReversibleSequence(nn.ModuleList(blocks))
        self.to_logits = nn.Linear(emb, num_tokens)
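The weight_tie option in the example above works by memoizing the module constructors, so every layer built in the loop receives the same attention and feed-forward instance. Below is a minimal sketch of what a cache_fn helper like the one used here could look like; it is an assumption for illustration, not necessarily the library's exact code.

from torch import nn

def cache_fn(f):
    # Memoize a zero-argument constructor: the first call builds the module,
    # every later call returns that same instance, which ties the weights
    # across all `depth` layers when weight_tie = True.
    cache = None
    def cached_fn(*args, **kwargs):
        nonlocal cache
        if cache is None:
            cache = f(*args, **kwargs)
        return cache
    return cached_fn

# Usage: both calls return the identical module, so its parameters are shared.
get_ff = cache_fn(lambda: nn.Linear(512, 512))
assert get_ff() is get_ff()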
Example #2
    def __init__(self, dim, depth, max_seq_len, heads = 8, bucket_size = 64, n_hashes = 8, ff_chunks = 100, attn_chunks = None, causal = False, weight_tie = False, lsh_dropout = 0., lsh_attend_across_buckets = True, lsh_allow_duplicate_attention = True, random_rotations_per_head = False, twin_attention = False, use_scale_norm = False, use_full_attn = False, full_attn_thres = 0, num_mem_kv = 0):
        super().__init__()
        self.dim = dim
        self.depth = depth

        get_attn = lambda: SettableArgs(LSHSelfAttention(dim, heads, bucket_size, n_hashes, causal = causal, dropout = lsh_dropout, attn_chunks = attn_chunks, allow_duplicate_attention = lsh_allow_duplicate_attention, attend_across_buckets = lsh_attend_across_buckets, random_rotations_per_head = random_rotations_per_head, num_mem_kv = num_mem_kv, use_full_attn = use_full_attn, full_attn_thres = full_attn_thres))
        get_ff = lambda: FeedForward(dim)

        if weight_tie:
            get_attn = cache_fn(get_attn)
            get_ff = cache_fn(get_ff)

        blocks = []
        norm_type = ScaleNorm if use_scale_norm else nn.LayerNorm

        for _ in range(depth):
            attn = get_attn()
            parallel_net = get_attn() if twin_attention else get_ff()

            f = WithNorm(norm_type, dim, attn)
            g = WithNorm(norm_type, dim, parallel_net)

            if not twin_attention and ff_chunks > 1:
                g = Chunk(ff_chunks, g, along_dim = -2)

            # fix_random_seed keeps dropout masks identical between the forward pass
            # and the recomputation performed in the reversible backward pass
            blocks.append(ReversibleBlock(f, g, split_along_dim=-1, fix_random_seed=True))

        self.layers = ReversibleSequence(nn.ModuleList(blocks))
        # keep direct references to the wrapped attention / feed-forward modules
        # (note: this attribute shadows nn.Module.modules())
        self.modules = list(chain(*[[m.f_block.fn, m.g_block.fn] for m in blocks]))
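Both constructors above wrap each sub-module in a pre-normalization layer, choosing between ScaleNorm and nn.LayerNorm via use_scale_norm. A minimal sketch of what ScaleNorm and the WithNorm wrapper might look like follows; the definitions are assumptions for illustration and may differ from the library's own.

import torch
from torch import nn

class ScaleNorm(nn.Module):
    # l2-normalizes the last dimension and rescales it with a single learned
    # scalar, a cheaper alternative to LayerNorm.
    def __init__(self, dim, eps = 1e-5):
        super().__init__()
        self.g = nn.Parameter(torch.ones(1))
        self.eps = eps

    def forward(self, x):
        norm = x.norm(dim = -1, keepdim = True).clamp(min = self.eps)
        return x / norm * self.g

class WithNorm(nn.Module):
    # Pre-norm wrapper: normalize the input, then apply the wrapped function.
    def __init__(self, norm_class, dim, fn):
        super().__init__()
        self.norm = norm_class(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)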
Example #3
    def __init__(self,
                 emb,
                 depth,
                 max_seq_len,
                 num_tokens=10000,
                 heads=8,
                 bucket_size=64,
                 n_hashes=8,
                 ff_chunks=100,
                 causal=False,
                 weight_tie=False,
                 lsh_dropout=0.):
        super().__init__()
        self.emb = emb
        self.depth = depth
        self.token_emb = nn.Embedding(num_tokens, emb)
        self.pos_emb = nn.Embedding(max_seq_len, emb)

        get_attn = lambda: LSHSelfAttention(emb,
                                            heads,
                                            bucket_size,
                                            n_hashes,
                                            causal=causal,
                                            dropout=lsh_dropout)
        get_ff = lambda: FeedForward(emb)

        if weight_tie:
            get_attn = cache_fn(get_attn)
            get_ff = cache_fn(get_ff)

        blocks = []

        for _ in range(depth):
            attn = get_attn()
            ff_net = get_ff()

            f = WithLayerNorm(emb, attn)
            g = Chunk(ff_chunks, WithLayerNorm(emb, ff_net), along_dim=-2)
            blocks.append(ReversibleBlock(f, g, split_along_dim=-1))

        self.layers = ReversibleSequence(nn.ModuleList(blocks))
        self.to_logits = nn.Linear(emb, num_tokens)
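All three examples split the feed-forward computation into ff_chunks pieces along the sequence dimension (along_dim = -2) to cap peak activation memory. Below is a minimal sketch of such a Chunk wrapper, again assumed for illustration rather than taken from the library.

import torch
from torch import nn

class Chunk(nn.Module):
    # Applies `fn` to the input in `chunks` slices along `along_dim` and
    # concatenates the results, trading a little speed for lower peak
    # memory in the feed-forward layers.
    def __init__(self, chunks, fn, along_dim = -1):
        super().__init__()
        self.chunks = chunks
        self.fn = fn
        self.dim = along_dim

    def forward(self, x, **kwargs):
        if self.chunks <= 1:
            return self.fn(x, **kwargs)
        pieces = x.chunk(self.chunks, dim = self.dim)
        return torch.cat([self.fn(p, **kwargs) for p in pieces], dim = self.dim)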