def __init__(self, dim, depth, max_seq_len, heads = 8, bucket_size = 64, n_hashes = 8, ff_chunks = 100, attn_chunks = None, causal = False, weight_tie = False, lsh_dropout = 0., layer_dropout = 0., lsh_attend_across_buckets = True, lsh_allow_duplicate_attention = True, random_rotations_per_head = False, twin_attention = False, use_scale_norm = False, use_full_attn = False, full_attn_thres = 0, num_mem_kv = 0, one_value_head = False):
        super().__init__()
        self.dim = dim
        self.depth = depth

        self.bucket_size = bucket_size
        self.num_mem_kv = num_mem_kv
        self.full_attn_thres = full_attn_thres

        get_attn = lambda: SettableArgs(LSHSelfAttention(dim, heads, bucket_size, n_hashes, causal = causal, dropout = lsh_dropout, attn_chunks = attn_chunks, allow_duplicate_attention = lsh_allow_duplicate_attention, attend_across_buckets = lsh_attend_across_buckets, random_rotations_per_head = random_rotations_per_head, num_mem_kv = num_mem_kv, use_full_attn = use_full_attn, full_attn_thres = full_attn_thres, one_value_head = one_value_head))
        get_ff = lambda: FeedForward(dim)

        if weight_tie:
            get_attn = cache_fn(get_attn)
            get_ff = cache_fn(get_ff)

        blocks = []
        norm_type = ScaleNorm if use_scale_norm else nn.LayerNorm

        for _ in range(depth):
            attn = get_attn()
            parallel_net = get_attn() if twin_attention else get_ff()

            f = WithNorm(norm_type, dim, attn)
            g = WithNorm(norm_type, dim, parallel_net)

            if not twin_attention and ff_chunks > 1:
                g = Chunk(ff_chunks, g, along_dim = -2)

            blocks.append(nn.ModuleList([f, g]))

        self.layers = ReversibleSequence(nn.ModuleList(blocks), layer_dropout = layer_dropout)
        # materialize the filter so the settable modules can be iterated more than once
        self.settables = list(filter(lambda x: isinstance(x, SettableArgs), self.layers.modules()))
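
The (f, g) pairs collected above are consumed by ReversibleSequence as the two halves of a RevNet-style coupling, which is what lets the model avoid storing per-layer activations. A minimal sketch of that coupling, assuming the input carries both residual streams concatenated on the feature axis (the real reversible block also implements a custom backward pass that recomputes activations; this is illustrative only):

import torch
from torch import nn

class ReversibleCouplingSketch(nn.Module):
    # Illustrative only: shows how one (f, g) pair is applied, not the
    # library's actual ReversibleSequence implementation.
    def __init__(self, f, g):
        super().__init__()
        self.f = f   # pre-normed attention
        self.g = g   # pre-normed feed-forward (or second attention)

    def forward(self, x):
        # x holds both residual streams: its last dim is 2 * dim
        x1, x2 = torch.chunk(x, 2, dim=-1)
        y1 = x1 + self.f(x2)
        y2 = x2 + self.g(y1)
        return torch.cat([y1, y2], dim=-1)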
Example #2
    def load_pretrained_gpt2_weights(self, config, pretrained_gpt2_model):
        print("load pretrained gpt2 weights")
        gpt2_layers = pretrained_gpt2_model.transformer.h
        reformer = self.transformer.h

        # get_attn = lambda: LSHSelfAttention(dim, heads, bucket_size, n_hashes, add_local_attn_hash = add_local_attn_hash, causal = causal, dropout = lsh_dropout, post_attn_dropout = post_attn_dropout, attn_chunks = attn_chunks, allow_duplicate_attention = lsh_allow_duplicate_attention, attend_across_buckets = lsh_attend_across_buckets, random_rotations_per_head = random_rotations_per_head, num_mem_kv = num_mem_kv, use_full_attn = use_full_attn, full_attn_thres = full_attn_thres, one_value_head = one_value_head)
        get_attn = lambda: Block(config.n_ctx, config, scale=True)
        # get_ff = lambda: FeedForward(config.n_dim, dropout = ff_dropout, activation = ff_activation, mult = ff_mult)
        get_ff = lambda: nn.Linear(config.n_embd, config.n_embd)

        if config.weight_tie:
            get_attn = cache_fn(get_attn)
            get_ff = cache_fn(get_ff)

        blocks = []
        norm_type = nn.LayerNorm

        for i in range(len(gpt2_layers)):
            attn = get_attn()
            # parallel_net = get_attn() if twin_attention else get_ff()
            parallel_net = get_ff()

            # load the weight
            attn.load_state_dict(gpt2_layers[i].state_dict())

            f = WithNorm(norm_type, config.n_embd, attn)
            g = WithNorm(norm_type, config.n_embd, parallel_net)

            # if not twin_attention and ff_chunks > 1:
            #     g = Chunk(ff_chunks, g, along_dim = -2)

            blocks.append(nn.ModuleList([f, g]))

        self.transformer.h.layers = ReversibleSequence(nn.ModuleList(blocks), layer_dropout = 0.1, reverse_thres = 0)
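
Both of the snippets above lean on cache_fn for weight tying: the module factory is wrapped so that every call returns the same instance, which makes all layers share parameters. A minimal sketch of such a memoizing wrapper (an assumption about its behaviour, not the library's verbatim code):

from functools import wraps

def cache_fn(f):
    # Build the module once, then hand back the cached instance on every
    # later call, so all depth layers end up sharing the same weights.
    cache = None

    @wraps(f)
    def cached_fn(*args, **kwargs):
        nonlocal cache
        if cache is None:
            cache = f(*args, **kwargs)
        return cache

    return cached_fn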
Example #3
    def __init__(self, config):
        super().__init__()
        self.dim = config.n_embd
        self.n_layer = config.n_layer
        self.config = config

        # get_attn = lambda: LSHSelfAttention(dim, heads, bucket_size, n_hashes, add_local_attn_hash = add_local_attn_hash, causal = causal, dropout = lsh_dropout, post_attn_dropout = post_attn_dropout, attn_chunks = attn_chunks, allow_duplicate_attention = lsh_allow_duplicate_attention, attend_across_buckets = lsh_attend_across_buckets, random_rotations_per_head = random_rotations_per_head, num_mem_kv = num_mem_kv, use_full_attn = use_full_attn, full_attn_thres = full_attn_thres, one_value_head = one_value_head)
        get_attn = lambda: Block(config.n_ctx, config, scale=True)
        # get_ff = lambda: FeedForward(config.n_dim, dropout = ff_dropout, activation = ff_activation, mult = ff_mult)
        get_ff = lambda: nn.Linear(config.n_embd, config.n_embd)

        if config.weight_tie:
            get_attn = cache_fn(get_attn)
            get_ff = cache_fn(get_ff)

        blocks = []
        norm_type = nn.LayerNorm

        for _ in range(self.n_layer):
            attn = get_attn()
            # parallel_net = get_attn() if twin_attention else get_ff()
            parallel_net = get_ff()

            f = WithNorm(norm_type, self.dim, attn)
            g = WithNorm(norm_type, self.dim, parallel_net)

            # if not twin_attention and ff_chunks > 1:
            #     g = Chunk(ff_chunks, g, along_dim = -2)

            blocks.append(nn.ModuleList([f, g]))

        self.layers = ReversibleSequence(nn.ModuleList(blocks), layer_dropout = 0.1, reverse_thres = 0)
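
The WithNorm wrapper used in Examples #1 through #3 is a pre-normalization helper: normalize the input, then run the wrapped sub-layer. A minimal sketch consistent with the WithNorm(norm_type, dim, fn) call sites above (the exact implementation is an assumption):

from torch import nn

class WithNorm(nn.Module):
    # Apply the chosen norm (nn.LayerNorm or ScaleNorm) before the wrapped
    # sub-layer, matching WithNorm(norm_type, dim, fn) as called above.
    def __init__(self, norm_class, dim, fn):
        super().__init__()
        self.norm = norm_class(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)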
Example #4
    def __init__(self, dim, depth, max_seq_len, heads = 8, dim_head = None, bucket_size = 64, n_hashes = 8, ff_chunks = 100, attn_chunks = None, causal = False, weight_tie = False, lsh_dropout = 0., ff_dropout = 0., ff_activation = None, ff_mult = 4, ff_glu = False, post_attn_dropout = 0., layer_dropout = 0., lsh_attend_across_buckets = True, lsh_allow_duplicate_attention = True, random_rotations_per_head = False, twin_attention = False, use_scale_norm = False, use_rezero = False, use_full_attn = False, full_attn_thres = 0, reverse_thres = 0, num_mem_kv = 0, one_value_head = False, n_local_attn_heads = 0, pkm_layers = tuple(), pkm_num_keys = 128, kernelized_lsh = False):
        super().__init__()
        self.dim = dim
        self.depth = depth

        self.bucket_size = bucket_size
        self.num_mem_kv = num_mem_kv

        self.twin_attention = twin_attention
        self.full_attn_thres = full_attn_thres
        
        
        get_attn = lambda: LSHSelfAttention(dim, heads, bucket_size, n_hashes, causal = causal, dim_head = dim_head, dropout = lsh_dropout, post_attn_dropout = post_attn_dropout, attn_chunks = attn_chunks, allow_duplicate_attention = lsh_allow_duplicate_attention, attend_across_buckets = lsh_attend_across_buckets, random_rotations_per_head = random_rotations_per_head, num_mem_kv = num_mem_kv, use_full_attn = use_full_attn, full_attn_thres = full_attn_thres, one_value_head = one_value_head, n_local_attn_heads = n_local_attn_heads, kernelized_lsh = kernelized_lsh)
        get_ff = lambda: Chunk(ff_chunks, FeedForward(dim, dropout = ff_dropout, activation = ff_activation, mult = ff_mult, glu = ff_glu), along_dim = -2)
        get_pkm = lambda: PKM(dim, num_keys = pkm_num_keys)

        if weight_tie:
            get_attn, get_ff, get_pkm = map(cache_fn, (get_attn, get_ff, get_pkm))

        blocks = []

        norm_type = ScaleNorm if use_scale_norm else nn.LayerNorm

        residual_fn_wrapper = ReZero if use_rezero else partial(PreNorm, norm_type, dim)

        for ind in range(depth):
            layer_num = ind + 1
            use_pkm = layer_num in cast_tuple(pkm_layers)
            parallel_net = None

            attn = get_attn()

            if use_pkm:
                parallel_net = get_pkm()
            elif twin_attention:
                parallel_net = get_attn()
            else:
                parallel_net = get_ff()

            f = residual_fn_wrapper(attn)
            g = residual_fn_wrapper(parallel_net)

            blocks.append(nn.ModuleList([f, g]))

        self.layers = ReversibleSequence(nn.ModuleList(blocks), layer_dropout = layer_dropout, reverse_thres = reverse_thres, send_signal = True)
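
Here get_ff wraps the feed-forward in Chunk, which slices the sequence dimension (along_dim = -2) and runs the feed-forward chunk by chunk to cap peak memory. A minimal sketch of that behaviour, inferred from the Chunk(ff_chunks, fn, along_dim = -2) call sites rather than taken from the library:

import torch
from torch import nn

class Chunk(nn.Module):
    # Split the input along `along_dim` (the sequence axis here), apply the
    # wrapped fn to each slice, and concatenate the results back together.
    def __init__(self, chunks, fn, along_dim=-1):
        super().__init__()
        self.chunks = chunks
        self.fn = fn
        self.dim = along_dim

    def forward(self, x, **kwargs):
        if self.chunks <= 1:
            return self.fn(x, **kwargs)
        parts = x.chunk(self.chunks, dim=self.dim)
        return torch.cat([self.fn(p, **kwargs) for p in parts], dim=self.dim)

Example #5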
    def __init__(self,
                 dim,
                 depth,
                 max_seq_len,
                 heads=8,
                 bucket_size=64,
                 n_hashes=8,
                 ff_chunks=100,
                 attn_chunks=None,
                 causal=False,
                 weight_tie=False,
                 lsh_dropout=0.,
                 ff_dropout=0.,
                 ff_activation=None,
                 ff_mult=4,
                 ff_glu=False,
                 post_attn_dropout=0.,
                 layer_dropout=0.,
                 lsh_attend_across_buckets=True,
                 lsh_allow_duplicate_attention=True,
                 random_rotations_per_head=False,
                 twin_attention=False,
                 use_scale_norm=False,
                 use_rezero=False,
                 use_full_attn=False,
                 full_attn_thres=0,
                 reverse_thres=0,
                 num_mem_kv=0,
                 one_value_head=False,
                 n_local_attn_heads=0):
        super().__init__()
        self.dim = dim
        self.depth = depth

        self.bucket_size = bucket_size
        self.num_mem_kv = num_mem_kv

        self.twin_attention = twin_attention
        self.full_attn_thres = full_attn_thres

        get_attn = lambda: LSHSelfAttention(
            dim,
            heads,
            bucket_size,
            n_hashes,
            causal=causal,
            dropout=lsh_dropout,
            post_attn_dropout=post_attn_dropout,
            attn_chunks=attn_chunks,
            allow_duplicate_attention=lsh_allow_duplicate_attention,
            attend_across_buckets=lsh_attend_across_buckets,
            random_rotations_per_head=random_rotations_per_head,
            num_mem_kv=num_mem_kv,
            use_full_attn=use_full_attn,
            full_attn_thres=full_attn_thres,
            one_value_head=one_value_head,
            n_local_attn_heads=n_local_attn_heads)
        get_ff = lambda: FeedForward(dim,
                                     dropout=ff_dropout,
                                     activation=ff_activation,
                                     mult=ff_mult,
                                     glu=ff_glu)

        if weight_tie:
            get_attn = cache_fn(get_attn)
            get_ff = cache_fn(get_ff)

        blocks = []

        norm_type = ScaleNorm if use_scale_norm else nn.LayerNorm

        residual_fn_wrapper = ReZero if use_rezero else partial(
            PreNorm, norm_type, dim)

        for _ in range(depth):
            attn = get_attn()
            parallel_net = get_attn() if twin_attention else get_ff()

            f = residual_fn_wrapper(attn)
            g = residual_fn_wrapper(parallel_net)

            if not twin_attention and ff_chunks > 1:
                g = Chunk(ff_chunks, g, along_dim=-2)

            blocks.append(nn.ModuleList([f, g]))

        self.layers = ReversibleSequence(nn.ModuleList(blocks),
                                         layer_dropout=layer_dropout,
                                         reverse_thres=reverse_thres,
                                         send_signal=True)
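
A hedged usage sketch for this last constructor, assuming it belongs to the Reformer class of reformer_pytorch (the class name is not visible in the snippet, and other releases may expose a slightly different signature):

import torch
# Assumption: the constructor above is Reformer.__init__ from reformer_pytorch;
# adjust the import if the class lives elsewhere in your codebase.
from reformer_pytorch import Reformer

model = Reformer(
    dim=512,
    depth=6,
    max_seq_len=8192,      # positional argument in the signature above
    heads=8,
    bucket_size=64,
    n_hashes=8,
    causal=True,
)

x = torch.randn(1, 8192, 512)   # (batch, seq_len, dim) token embeddings
y = model(x)                    # expected to preserve the (1, 8192, 512) shape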