def __init__(self, dim, depth, max_seq_len, heads = 8, bucket_size = 64, n_hashes = 8, ff_chunks = 100, attn_chunks = None, causal = False, weight_tie = False, lsh_dropout = 0., layer_dropout = 0., lsh_attend_across_buckets = True, lsh_allow_duplicate_attention = True, random_rotations_per_head = False, twin_attention = False, use_scale_norm = False, use_full_attn = False, full_attn_thres = 0, num_mem_kv = 0, one_value_head = False):
    super().__init__()
    self.dim = dim
    self.depth = depth
    self.bucket_size = bucket_size
    self.num_mem_kv = num_mem_kv
    self.full_attn_thres = full_attn_thres

    get_attn = lambda: SettableArgs(LSHSelfAttention(dim, heads, bucket_size, n_hashes, causal = causal, dropout = lsh_dropout, attn_chunks = attn_chunks, allow_duplicate_attention = lsh_allow_duplicate_attention, attend_across_buckets = lsh_attend_across_buckets, random_rotations_per_head = random_rotations_per_head, num_mem_kv = num_mem_kv, use_full_attn = use_full_attn, full_attn_thres = full_attn_thres, one_value_head = one_value_head))
    get_ff = lambda: FeedForward(dim)

    if weight_tie:
        # share one attention module and one feed-forward module across all layers
        get_attn = cache_fn(get_attn)
        get_ff = cache_fn(get_ff)

    blocks = []
    norm_type = ScaleNorm if use_scale_norm else nn.LayerNorm

    for _ in range(depth):
        attn = get_attn()
        parallel_net = get_attn() if twin_attention else get_ff()

        f = WithNorm(norm_type, dim, attn)
        g = WithNorm(norm_type, dim, parallel_net)

        if not twin_attention and ff_chunks > 1:
            # run the feed-forward branch in chunks along the sequence dimension to save memory
            g = Chunk(ff_chunks, g, along_dim = -2)

        blocks.append(nn.ModuleList([f, g]))

    self.layers = ReversibleSequence(nn.ModuleList(blocks), layer_dropout = layer_dropout)
    # materialize the settable attention wrappers (a bare filter object can only be iterated once)
    self.settables = list(filter(lambda x: isinstance(x, SettableArgs), self.layers.modules()))
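# Note on the f/g pairing above: ReversibleSequence applies the standard reversible
# residual coupling from RevNets/Reformer, so layer activations can be recomputed on
# the backward pass instead of stored. A minimal sketch of that coupling, independent
# of the library's actual ReversibleSequence implementation:

import torch

def reversible_forward(x1, x2, f, g):
    # y1 = x1 + f(x2), y2 = x2 + g(y1): f is the attention branch, g the parallel branch
    y1 = x1 + f(x2)
    y2 = x2 + g(y1)
    return y1, y2

def reversible_inverse(y1, y2, f, g):
    # the inputs are recovered from the outputs (up to floating-point rounding),
    # so no intermediate activations need to be cached during the forward pass
    x2 = y2 - g(y1)
    x1 = y1 - f(x2)
    return x1, x2

# round-trip check with an arbitrary branch function
f = g = torch.nn.Linear(512, 512)
x1, x2 = torch.randn(1, 16, 512), torch.randn(1, 16, 512)
y1, y2 = reversible_forward(x1, x2, f, g)
r1, r2 = reversible_inverse(y1, y2, f, g)
assert torch.allclose(r1, x1, atol=1e-5) and torch.allclose(r2, x2, atol=1e-5)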
def load_pretrained_gpt2_weights(self, config, pretrained_gpt2_model):
    print("load pretrained gpt2 weights")
    gpt2_layers = pretrained_gpt2_model.transformer.h
    reformer = self.transformer.h

    # get_attn = lambda: LSHSelfAttention(dim, heads, bucket_size, n_hashes, add_local_attn_hash = add_local_attn_hash, causal = causal, dropout = lsh_dropout, post_attn_dropout = post_attn_dropout, attn_chunks = attn_chunks, allow_duplicate_attention = lsh_allow_duplicate_attention, attend_across_buckets = lsh_attend_across_buckets, random_rotations_per_head = random_rotations_per_head, num_mem_kv = num_mem_kv, use_full_attn = use_full_attn, full_attn_thres = full_attn_thres, one_value_head = one_value_head)
    get_attn = lambda: Block(config.n_ctx, config, scale=True)  # a GPT-2 transformer block stands in for LSH attention
    # get_ff = lambda: FeedForward(config.n_dim, dropout = ff_dropout, activation = ff_activation, mult = ff_mult)
    get_ff = lambda: nn.Linear(config.n_embd, config.n_embd)    # freshly initialized parallel branch

    if config.weight_tie:
        get_attn = cache_fn(get_attn)
        get_ff = cache_fn(get_ff)

    blocks = []
    norm_type = nn.LayerNorm

    for i in range(len(gpt2_layers)):
        attn = get_attn()
        # parallel_net = get_attn() if twin_attention else get_ff()
        parallel_net = get_ff()

        # copy the pretrained GPT-2 block weights into the attention branch
        attn.load_state_dict(gpt2_layers[i].state_dict())

        f = WithNorm(norm_type, config.n_embd, attn)
        g = WithNorm(norm_type, config.n_embd, parallel_net)

        # if not twin_attention and ff_chunks > 1:
        #     g = Chunk(ff_chunks, g, along_dim = -2)

        blocks.append(nn.ModuleList([f, g]))

    self.transformer.h.layers = ReversibleSequence(nn.ModuleList(blocks), layer_dropout = 0.1, reverse_thres = 0)
def __init__(self, config):
    super().__init__()
    self.dim = config.n_embd
    self.n_layer = config.n_layer
    self.config = config

    # get_attn = lambda: LSHSelfAttention(dim, heads, bucket_size, n_hashes, add_local_attn_hash = add_local_attn_hash, causal = causal, dropout = lsh_dropout, post_attn_dropout = post_attn_dropout, attn_chunks = attn_chunks, allow_duplicate_attention = lsh_allow_duplicate_attention, attend_across_buckets = lsh_attend_across_buckets, random_rotations_per_head = random_rotations_per_head, num_mem_kv = num_mem_kv, use_full_attn = use_full_attn, full_attn_thres = full_attn_thres, one_value_head = one_value_head)
    get_attn = lambda: Block(config.n_ctx, config, scale=True)
    # get_ff = lambda: FeedForward(config.n_dim, dropout = ff_dropout, activation = ff_activation, mult = ff_mult)
    get_ff = lambda: nn.Linear(config.n_embd, config.n_embd)

    if config.weight_tie:
        get_attn = cache_fn(get_attn)
        get_ff = cache_fn(get_ff)

    blocks = []
    norm_type = nn.LayerNorm

    for _ in range(self.n_layer):
        attn = get_attn()
        # parallel_net = get_attn() if twin_attention else get_ff()
        parallel_net = get_ff()

        f = WithNorm(norm_type, self.dim, attn)
        g = WithNorm(norm_type, self.dim, parallel_net)

        # if not twin_attention and ff_chunks > 1:
        #     g = Chunk(ff_chunks, g, along_dim = -2)

        blocks.append(nn.ModuleList([f, g]))

    self.layers = ReversibleSequence(nn.ModuleList(blocks), layer_dropout = 0.1, reverse_thres = 0)
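# Hedged usage sketch for the GPT-2 path above. The class name ReformerGPT2LMHeadModel
# is hypothetical (the real wrapper class around these methods is defined elsewhere),
# and `weight_tie` is assumed to be an extra attribute attached to the config; only the
# `transformers` calls below are standard API.

from transformers import GPT2Config, GPT2LMHeadModel

config = GPT2Config.from_pretrained("gpt2")
config.weight_tie = False                       # extra flag read by the constructor above

model = ReformerGPT2LMHeadModel(config)         # hypothetical wrapper whose transformer.h uses the __init__ above
gpt2 = GPT2LMHeadModel.from_pretrained("gpt2")

# copies each pretrained GPT-2 block into the attention branch of a reversible layer;
# the parallel nn.Linear branches remain randomly initialized
model.load_pretrained_gpt2_weights(config, gpt2)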
def __init__(self, dim, depth, max_seq_len, heads = 8, dim_head = None, bucket_size = 64, n_hashes = 8, ff_chunks = 100, attn_chunks = None, causal = False, weight_tie = False, lsh_dropout = 0., ff_dropout = 0., ff_activation = None, ff_mult = 4, ff_glu = False, post_attn_dropout = 0., layer_dropout = 0., lsh_attend_across_buckets = True, lsh_allow_duplicate_attention = True, random_rotations_per_head = False, twin_attention = False, use_scale_norm = False, use_rezero = False, use_full_attn = False, full_attn_thres = 0, reverse_thres = 0, num_mem_kv = 0, one_value_head = False, n_local_attn_heads = 0, pkm_layers = tuple(), pkm_num_keys = 128, kernelized_lsh = False):
    super().__init__()
    self.dim = dim
    self.depth = depth

    self.bucket_size = bucket_size
    self.num_mem_kv = num_mem_kv
    self.twin_attention = twin_attention
    self.full_attn_thres = full_attn_thres

    get_attn = lambda: LSHSelfAttention(dim, heads, bucket_size, n_hashes, causal = causal, dim_head = dim_head, dropout = lsh_dropout, post_attn_dropout = post_attn_dropout, attn_chunks = attn_chunks, allow_duplicate_attention = lsh_allow_duplicate_attention, attend_across_buckets = lsh_attend_across_buckets, random_rotations_per_head = random_rotations_per_head, num_mem_kv = num_mem_kv, use_full_attn = use_full_attn, full_attn_thres = full_attn_thres, one_value_head = one_value_head, n_local_attn_heads = n_local_attn_heads, kernelized_lsh = kernelized_lsh)
    get_ff = lambda: Chunk(ff_chunks, FeedForward(dim, dropout = ff_dropout, activation = ff_activation, mult = ff_mult, glu = ff_glu), along_dim = -2)
    get_pkm = lambda: PKM(dim, num_keys = pkm_num_keys)

    if weight_tie:
        get_attn, get_ff, get_pkm = map(cache_fn, (get_attn, get_ff, get_pkm))

    blocks = []

    norm_type = ScaleNorm if use_scale_norm else nn.LayerNorm
    residual_fn_wrapper = ReZero if use_rezero else partial(PreNorm, norm_type, dim)

    for ind in range(depth):
        layer_num = ind + 1
        use_pkm = layer_num in cast_tuple(pkm_layers)
        parallel_net = None

        attn = get_attn()

        if use_pkm:
            parallel_net = get_pkm()
        elif twin_attention:
            parallel_net = get_attn()
        else:
            parallel_net = get_ff()

        f = residual_fn_wrapper(attn)
        g = residual_fn_wrapper(parallel_net)

        blocks.append(nn.ModuleList([f, g]))

    self.layers = ReversibleSequence(nn.ModuleList(blocks), layer_dropout = layer_dropout, reverse_thres = reverse_thres, send_signal = True)
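# The signature above matches the Reformer module in lucidrains' reformer_pytorch,
# which operates directly on embeddings of shape (batch, seq_len, dim). A minimal
# instantiation sketch under that assumption, mirroring the package README:

import torch
from reformer_pytorch import Reformer   # assumes this file tracks the published package

model = Reformer(
    dim = 512,
    depth = 6,
    max_seq_len = 8192,
    heads = 8,
    bucket_size = 64,     # sequence length should be a multiple of bucket_size * 2
    n_hashes = 4,
    lsh_dropout = 0.1,
    causal = True
)

x = torch.randn(1, 8192, 512)   # (batch, seq_len, dim) embeddings, not token ids
y = model(x)                    # same shape out: (1, 8192, 512)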
def __init__(self, dim, depth, max_seq_len, heads=8, bucket_size=64, n_hashes=8, ff_chunks=100, attn_chunks=None, causal=False, weight_tie=False, lsh_dropout=0., ff_dropout=0., ff_activation=None, ff_mult=4, ff_glu=False, post_attn_dropout=0., layer_dropout=0., lsh_attend_across_buckets=True, lsh_allow_duplicate_attention=True, random_rotations_per_head=False, twin_attention=False, use_scale_norm=False, use_rezero=False, use_full_attn=False, full_attn_thres=0, reverse_thres=0, num_mem_kv=0, one_value_head=False, n_local_attn_heads=0):
    super().__init__()
    self.dim = dim
    self.depth = depth

    self.bucket_size = bucket_size
    self.num_mem_kv = num_mem_kv
    self.twin_attention = twin_attention
    self.full_attn_thres = full_attn_thres

    get_attn = lambda: LSHSelfAttention(
        dim, heads, bucket_size, n_hashes, causal=causal,
        dropout=lsh_dropout, post_attn_dropout=post_attn_dropout,
        attn_chunks=attn_chunks,
        allow_duplicate_attention=lsh_allow_duplicate_attention,
        attend_across_buckets=lsh_attend_across_buckets,
        random_rotations_per_head=random_rotations_per_head,
        num_mem_kv=num_mem_kv, use_full_attn=use_full_attn,
        full_attn_thres=full_attn_thres, one_value_head=one_value_head,
        n_local_attn_heads=n_local_attn_heads)
    get_ff = lambda: FeedForward(dim, dropout=ff_dropout,
                                 activation=ff_activation, mult=ff_mult,
                                 glu=ff_glu)

    if weight_tie:
        get_attn = cache_fn(get_attn)
        get_ff = cache_fn(get_ff)

    blocks = []
    norm_type = ScaleNorm if use_scale_norm else nn.LayerNorm
    residual_fn_wrapper = ReZero if use_rezero else partial(
        PreNorm, norm_type, dim)

    for _ in range(depth):
        attn = get_attn()
        parallel_net = get_attn() if twin_attention else get_ff()

        f = residual_fn_wrapper(attn)
        g = residual_fn_wrapper(parallel_net)

        if not twin_attention and ff_chunks > 1:
            g = Chunk(ff_chunks, g, along_dim=-2)

        blocks.append(nn.ModuleList([f, g]))

    self.layers = ReversibleSequence(nn.ModuleList(blocks),
                                     layer_dropout=layer_dropout,
                                     reverse_thres=reverse_thres,
                                     send_signal=True)
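# weight_tie=True routes get_attn / get_ff through cache_fn, so every layer reuses the
# same attention and feed-forward instances (ALBERT-style parameter sharing across
# depth). cache_fn itself is defined elsewhere in the file; a typical memoizing helper
# of this kind looks roughly like the sketch below (an assumption, not the file's code):

def cache_fn(f):
    # remember the first constructed module and hand it back on every later call
    cache = None
    def cached_fn(*args, **kwargs):
        nonlocal cache
        if cache is None:
            cache = f(*args, **kwargs)
        return cache
    return cached_fn

# with weight_tie=True, each loop iteration's get_attn() / get_ff() then returns the
# same module objects, so all `depth` layers share parameters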