Example #1
    def __init__(self, num_tokens, dim, depth, max_seq_len, heads = 8, dim_head = 64, causal = False, emb_dim = None, reversible = False, ff_chunks = 1, ff_glu = False, ff_dropout = 0., attn_layer_dropout = 0., attn_dropout = 0., blindspot_size = 1, n_local_attn_heads = 0, local_attn_window_size = 128, return_embeddings = False, receives_context = False, pkm_layers = tuple(), pkm_num_keys = 128, attend_axially = False, linformer_settings = None, context_linformer_settings = None, use_axial_pos_emb = True, use_rotary_emb = False):
        assert n_local_attn_heads == 0 or (max_seq_len % local_attn_window_size) == 0, 'max sequence length must be divisible by the local attention window size'
        super().__init__()
        emb_dim = default(emb_dim, dim)
        self.max_seq_len = max_seq_len

        self.token_emb = nn.Embedding(num_tokens, emb_dim)

        if use_rotary_emb:
            self.pos_emb = FixedPositionalEmbedding(emb_dim, max_seq_len)
            self.layer_pos_emb = FixedPositionalEmbedding(dim_head, max_seq_len)
        elif use_axial_pos_emb:
            self.pos_emb = AxialPositionalEmbedding(emb_dim, axial_shape=(math.ceil(max_seq_len / local_attn_window_size), local_attn_window_size))
            self.layer_pos_emb = always(None)
        else:
            self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len)
            self.layer_pos_emb = always(None)

        self.transformer = LinearAttentionTransformer(dim, depth, max_seq_len, heads = heads, dim_head = dim_head, causal = causal, ff_chunks = ff_chunks, ff_glu = ff_glu, ff_dropout = ff_dropout, attn_layer_dropout = attn_layer_dropout, attn_dropout = attn_dropout, reversible = reversible, blindspot_size = blindspot_size, n_local_attn_heads = n_local_attn_heads, local_attn_window_size = local_attn_window_size, receives_context = receives_context, pkm_layers = pkm_layers, pkm_num_keys = pkm_num_keys, attend_axially = attend_axially, linformer_settings = linformer_settings, context_linformer_settings = context_linformer_settings)

        if emb_dim != dim:
            self.transformer = ProjectInOut(self.transformer, emb_dim, dim, project_out = not return_embeddings)

        self.norm = nn.LayerNorm(dim)
        self.out = nn.Linear(emb_dim, num_tokens) if not return_embeddings else nn.Identity()
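A minimal usage sketch for the constructor above. The class name LinearAttentionTransformerLM and the import path are assumptions (the snippet only shows an __init__ body, though the signature matches lucidrains' linear-attention-transformer package); only keyword arguments that appear in the signature are used.

# Hedged sketch: class name and import path are assumed, not shown in the snippet.
import torch
from linear_attention_transformer import LinearAttentionTransformerLM

model = LinearAttentionTransformerLM(
    num_tokens = 20000,
    dim = 512,
    depth = 6,
    max_seq_len = 1024,
    causal = True,
    n_local_attn_heads = 4,        # requires max_seq_len % local_attn_window_size == 0
    local_attn_window_size = 128,
)

tokens = torch.randint(0, 20000, (1, 1024))
logits = model(tokens)             # expected shape: (1, 1024, 20000)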
Example #2
    def __init__(self, num_tokens, dim, depth, max_seq_len, heads = 8, dim_head = 64, bucket_size = 64, n_hashes = 4, ff_chunks = 100, attn_chunks = 1, causal = False, weight_tie = False, lsh_dropout = 0., ff_dropout = 0., ff_mult = 4, ff_activation = None, ff_glu = False, post_attn_dropout = 0., layer_dropout = 0., random_rotations_per_head = False, use_scale_norm = False, use_rezero = False, use_full_attn = False, full_attn_thres = 0, reverse_thres = 0, num_mem_kv = 0, one_value_head = False, emb_dim = None, return_embeddings = False, weight_tie_embedding = False, fixed_position_emb = False, absolute_position_emb = False, rotary_emb = True, axial_position_shape = None, n_local_attn_heads = 0, pkm_layers = tuple(), pkm_num_keys = 128):
        super().__init__()
        emb_dim = default(emb_dim, dim)
        self.max_seq_len = max_seq_len

        self.token_emb = nn.Embedding(num_tokens, emb_dim)

        self.to_model_dim = Identity() if emb_dim == dim else nn.Linear(emb_dim, dim)

        self.pos_emb = Always(0)
        self.layer_pos_emb = Always(None)

        if rotary_emb:
            self.layer_pos_emb = FixedPositionalEmbedding(dim_head)
        elif absolute_position_emb:
            self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len)
        elif fixed_position_emb:
            self.pos_emb = FixedPositionalEmbedding(emb_dim)
        else:
            axial_position_shape = default(axial_position_shape, (math.ceil(max_seq_len / bucket_size), bucket_size))
            self.pos_emb = AxialPositionalEmbedding(emb_dim, axial_position_shape)

        self.reformer = Reformer(dim, depth, max_seq_len, heads = heads, dim_head = dim_head, bucket_size = bucket_size, n_hashes = n_hashes, ff_chunks = ff_chunks, attn_chunks = attn_chunks, causal = causal, weight_tie = weight_tie, lsh_dropout = lsh_dropout, ff_mult = ff_mult, ff_activation = ff_activation, ff_glu = ff_glu, ff_dropout = ff_dropout, post_attn_dropout = post_attn_dropout, layer_dropout = layer_dropout, random_rotations_per_head = random_rotations_per_head, use_scale_norm = use_scale_norm, use_rezero = use_rezero, use_full_attn = use_full_attn, full_attn_thres = full_attn_thres, reverse_thres = reverse_thres, num_mem_kv = num_mem_kv, one_value_head = one_value_head, n_local_attn_heads = n_local_attn_heads, pkm_layers = pkm_layers, pkm_num_keys = pkm_num_keys)
        self.norm = nn.LayerNorm(dim)

        if return_embeddings:
            self.out = Identity()
            return

        self.out = nn.Sequential(
            nn.Linear(dim, emb_dim) if emb_dim != dim else Identity(),
            nn.Linear(emb_dim, num_tokens) if not weight_tie_embedding else MatrixMultiply(self.token_emb.weight, transpose=True, normalize=True)
        )
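A hedged usage sketch for the constructor above, assuming it belongs to ReformerLM from the reformer_pytorch package (only the __init__ body is shown); all keyword arguments come from the signature.

import torch
from reformer_pytorch import ReformerLM  # assumed import path

model = ReformerLM(
    num_tokens = 20000,
    dim = 512,
    depth = 6,
    max_seq_len = 4096,
    heads = 8,
    bucket_size = 64,      # LSH bucket size
    n_hashes = 4,          # number of hash rounds
    causal = True,
)

tokens = torch.randint(0, 20000, (1, 4096))
logits = model(tokens)     # expected shape: (1, 4096, 20000)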
Example #3
    def __init__(self,
                 emb_sz: int,
                 d_emb: int,
                 max_seq_len: int = 512,
                 dropout: float = 0.,
                 pos_enc: str = 'absolute',
                 axial_shape: Tuple = None,
                 axial_emb_dims: Tuple = None):
        store_attr('d_emb')
        self.scale = d_emb**0.5
        self.std = 0.02  # fairseq: d_emb ** -0.5, fastai: 0.01
        self.emb = nn.Embedding(emb_sz, d_emb)
        self.dropout = nn.Dropout(dropout)

        if pos_enc == 'absolute':
            self.pos_enc = AbsolutePositionalEmbedding(d_emb, max_seq_len)
        elif pos_enc == 'fixed':
            self.pos_enc = FixedPositionalEmbedding(d_emb)
        elif pos_enc == 'axial':
            assert axial_shape is not None
            assert reduce(mul, axial_shape) == max_seq_len
            axial_emb_dims = default(axial_emb_dims,
                                     get_axial_dims(d_emb, len(axial_shape)))
            self.pos_enc = AxialPositionalEmbedding(d_emb, axial_shape,
                                                    axial_emb_dims)
        self._init()
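The 'axial' branch above asserts that the product of axial_shape equals max_seq_len. A small self-contained check of that constraint, using the default max_seq_len of 512 and a hypothetical shape:

from functools import reduce
from operator import mul

max_seq_len = 512
axial_shape = (32, 16)     # hypothetical shape; 32 * 16 == 512
assert reduce(mul, axial_shape) == max_seq_len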
Example #4
    def __init__(
        self,
        *,
        num_tokens,
        max_seq_len,
        dim,
        depth,
        heads,
        dim_head = 64,
        local_attn_heads = 0,
        local_window_size = 256,
        causal = False,
        ff_mult = 4,
        nb_features = None,
        feature_redraw_interval = 1000,
        reversible = False,
        ff_chunks = 1,
        ff_glu = False,
        emb_dropout = 0.,
        ff_dropout = 0.,
        attn_dropout = 0.,
        generalized_attention = False,
        kernel_fn = nn.ReLU(),
        use_scalenorm = False,
        use_rezero = False,
        cross_attend = False,
        no_projection = False,
        tie_embed = False,
        rotary_position_emb = True,
        axial_position_emb = False,
        axial_position_shape = None,
        auto_check_redraw = True,
        qkv_bias = False,
        attn_out_bias = False,
        shift_tokens = False
    ):
        super().__init__()
        local_attn_heads = cast_tuple(local_attn_heads)

        self.max_seq_len = max_seq_len
        self.token_emb = nn.Embedding(num_tokens, dim)

        if rotary_position_emb:
            self.pos_emb = FixedPositionalEmbedding(dim, max_seq_len)
            self.layer_pos_emb = FixedPositionalEmbedding(dim_head, max_seq_len)
        elif axial_position_emb:
            axial_position_shape = default(axial_position_shape, (math.ceil(max_seq_len / 64), 64))
            self.pos_emb = AxialPositionalEmbedding(dim, axial_position_shape)
            self.layer_pos_emb = Always(None)
        else:
            self.pos_emb = AbsolutePositionalEmbedding(dim, max_seq_len)
            self.layer_pos_emb = Always(None)

        self.dropout = nn.Dropout(emb_dropout)

        self.performer = Performer(dim, depth, heads, dim_head, local_attn_heads, local_window_size, causal, ff_mult, nb_features, feature_redraw_interval, reversible, ff_chunks, generalized_attention, kernel_fn, use_scalenorm, use_rezero, ff_glu, ff_dropout, attn_dropout, cross_attend, no_projection, auto_check_redraw, qkv_bias, attn_out_bias, shift_tokens)
        self.norm = nn.LayerNorm(dim)
        self.to_out = nn.Linear(dim, num_tokens) if not tie_embed else None
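A hedged usage sketch, assuming the snippet is PerformerLM from the performer_pytorch package (the class name is not shown); arguments are passed as keywords, matching the keyword-only `*` in the signature.

import torch
from performer_pytorch import PerformerLM  # assumed import path

model = PerformerLM(
    num_tokens = 20000,
    max_seq_len = 2048,
    dim = 512,
    depth = 6,
    heads = 8,
    causal = True,
    nb_features = 256,     # random features for the attention kernel approximation
)

tokens = torch.randint(0, 20000, (1, 2048))
logits = model(tokens)     # expected shape: (1, 2048, 20000)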
Example #5
    def build(self):

        # to be further set
        # breakpoint()
        self.image_feature_module = build_image_encoder(
            self.config.image_feature_processor, direct_features=True
        )
        if self.config.concate_trace:
            self.trace_feature_module = build_encoder(self.config.trace_feature_encoder)

        if self.config.base_model_name == "bert-base-uncased":
            self.encoderdecoder = EncoderDecoderModel.from_encoder_decoder_pretrained(
                "bert-base-uncased", "bert-base-uncased"
            )
        elif self.config.base_model_name == "2layer-base":
            config_encoder = BertConfig()
            config_decoder = BertConfig()
            config_encoder.max_position_embeddings = 1090
            config_encoder.num_hidden_layers = 2
            config_decoder.num_hidden_layers = 2
            self.codec_config = EncoderDecoderConfig.from_encoder_decoder_configs(
                config_encoder, config_decoder
            )
            self.encoderdecoder = EncoderDecoderModel(config=self.codec_config)
        elif self.config.base_model_name == "3layer-base":
            config_encoder = BertConfig()
            config_decoder = BertConfig()
            config_encoder.num_hidden_layers = 3
            config_decoder.num_hidden_layers = 3
            self.codec_config = EncoderDecoderConfig.from_encoder_decoder_configs(
                config_encoder, config_decoder
            )
            self.encoderdecoder = EncoderDecoderModel(config=self.codec_config)
        if self.config.loop_contrastive:
            self.trace_caption_contrastive = TraceCaptionContrastiveModel(
                self.config.tc_contrastive_aggregate_method
            )
        if (
            hasattr(self.config, "pretrans_attention")
            and self.config.pretrans_attention
        ):

            # import ipdb; ipdb.set_trace()
            tempconf = self.encoderdecoder.config.encoder
            num_heads = tempconf.num_attention_heads
            num_layers = tempconf.num_hidden_layers
            self.attention_trans = AttentionTransform(num_layers, num_heads, 100)
        self.BOS_ID = 101
        self.vae = OpenAIDiscreteVAE()
        image_code_dim = 768
        image_fmap_size = self.vae.image_size // (2 ** self.vae.num_layers)
        self.image_seq_len = image_fmap_size ** 2
        self.image_emb = torch.nn.Embedding(self.vae.num_tokens, image_code_dim)
        self.image_pos_emb = AxialPositionalEmbedding(
            image_code_dim, axial_shape=(image_fmap_size, image_fmap_size)
        )
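The "2layer-base" branch above builds a small BERT encoder-decoder with the Hugging Face Transformers API. A standalone sketch of that pattern, mirroring the configuration values in the snippet:

from transformers import BertConfig, EncoderDecoderConfig, EncoderDecoderModel

config_encoder = BertConfig(num_hidden_layers=2, max_position_embeddings=1090)
config_decoder = BertConfig(num_hidden_layers=2)

# from_encoder_decoder_configs marks the decoder config as a decoder with cross-attention
codec_config = EncoderDecoderConfig.from_encoder_decoder_configs(config_encoder, config_decoder)
model = EncoderDecoderModel(config=codec_config)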
Example #6
    def __init__(self, num_tokens, dim, max_seq_len, depth, heads = 8, bucket_size = 64, kv_bucket_size = None, context_bucket_size = None, causal = False, non_permutative = True, sinkhorn_iter = 5, n_sortcut = 0, temperature = 0.75, reversible = False, ff_chunks = 1, ff_glu = False, return_embeddings = False, ff_dropout = 0., attn_dropout = 0., attn_layer_dropout = 0., layer_dropout = 0., weight_tie = False, emb_dim = None, use_simple_sort_net = None, receives_context = False, context_n_sortcut = 0, n_local_attn_heads = 0, use_rezero = False, n_top_buckets = 2, pkm_layers = tuple(), pkm_num_keys = 128):
        super().__init__()
        emb_dim = default(emb_dim, dim)
        self.max_seq_len = max_seq_len

        self.to_token_emb = nn.Embedding(num_tokens, emb_dim)
        self.pos_emb = nn.Embedding(max_seq_len, emb_dim)
        self.axial_pos_emb = AxialPositionalEmbedding(emb_dim, axial_shape = (max_seq_len // bucket_size, bucket_size))
        self.sinkhorn_transformer = SinkhornTransformer(dim, depth, max_seq_len = max_seq_len, causal = causal, heads = heads, bucket_size = bucket_size, kv_bucket_size = kv_bucket_size, context_bucket_size = context_bucket_size, non_permutative = non_permutative, sinkhorn_iter = sinkhorn_iter, n_sortcut = n_sortcut, temperature = temperature, reversible = reversible, ff_chunks = ff_chunks, ff_dropout = ff_dropout, attn_dropout = attn_dropout, attn_layer_dropout = attn_layer_dropout, layer_dropout = layer_dropout, weight_tie = weight_tie, ff_glu = ff_glu, use_simple_sort_net = use_simple_sort_net, receives_context = receives_context, context_n_sortcut = context_n_sortcut, n_local_attn_heads = n_local_attn_heads, use_rezero = use_rezero, n_top_buckets = n_top_buckets,  pkm_layers = pkm_layers, pkm_num_keys = pkm_num_keys)

        if emb_dim != dim:
            self.sinkhorn_transformer = ProjectInOut(self.sinkhorn_transformer, emb_dim, dim, project_out =(not return_embeddings))

        self.to_logits = identity if return_embeddings else nn.Linear(emb_dim, num_tokens)
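A hedged usage sketch, assuming the snippet is SinkhornTransformerLM from the sinkhorn_transformer package; note the positional order here is (num_tokens, dim, max_seq_len, depth), so keyword arguments are used throughout.

import torch
from sinkhorn_transformer import SinkhornTransformerLM  # assumed import path

model = SinkhornTransformerLM(
    num_tokens = 20000,
    dim = 512,
    max_seq_len = 2048,
    depth = 6,
    heads = 8,
    bucket_size = 64,      # max_seq_len should be divisible by bucket_size
    causal = True,
)

tokens = torch.randint(0, 20000, (1, 2048))
logits = model(tokens)     # expected shape: (1, 2048, 20000)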
Example #7
    def __init__(self, num_tokens, dim, depth, max_seq_len, heads = 8, dim_head = None, window_size = 64, local_attn_window_size = None, causal = False, emb_dim = None, weight_tie = False, attn_dropout = 0., ff_dropout = 0., attn_layer_dropout = 0., layer_dropout = 0., ff_mult = 4, ff_activation = None, ff_glu = False, return_embeddings = False, n_local_attn_heads = 0, reversible = False, ff_chunks = 1, kmeans_ema_decay = 0.999, commitment_factor = 1e-4, receives_context = False, context_window_size = None, rel_pos_emb = True, _register_kmeans_update = True, pkm_layers = tuple(), pkm_num_keys = 128, moe_layers = tuple(), moe_num_experts = 4, moe_loss_coef = 1e-2, num_mem_kv = 0, shared_qk = None, context_shared_qk = False):
        super().__init__()
        assert (max_seq_len % window_size) == 0, 'max sequence length must be divisible by the window size, to calculate the number of kmeans clusters'
        emb_dim = default(emb_dim, dim)
        self.max_seq_len = max_seq_len

        self.token_emb = nn.Embedding(num_tokens, emb_dim)
        self.axial_pos_emb = AxialPositionalEmbedding(emb_dim, axial_shape=(max_seq_len // window_size, window_size))
        self.routing_transformer = RoutingTransformer(dim, depth, max_seq_len, heads = heads, dim_head = dim_head, window_size = window_size, local_attn_window_size = local_attn_window_size, causal = causal, weight_tie = weight_tie, ff_dropout = ff_dropout, attn_dropout = attn_dropout, attn_layer_dropout = attn_layer_dropout, layer_dropout = layer_dropout, n_local_attn_heads = n_local_attn_heads, ff_glu = ff_glu, reversible = reversible, ff_chunks = ff_chunks, kmeans_ema_decay = kmeans_ema_decay, receives_context = receives_context, context_window_size = context_window_size, rel_pos_emb = rel_pos_emb, pkm_layers = pkm_layers, pkm_num_keys = pkm_num_keys,  moe_layers = moe_layers, moe_num_experts = moe_num_experts, moe_loss_coef = moe_loss_coef, num_mem_kv = num_mem_kv, shared_qk = shared_qk, context_shared_qk = context_shared_qk, _register_kmeans_update = _register_kmeans_update)

        if emb_dim != dim:
            self.routing_transformer = ProjectInOut(self.routing_transformer, emb_dim, dim, project_out = not return_embeddings)

        self.out = nn.Linear(emb_dim, num_tokens) if not return_embeddings else identity
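A hedged usage sketch, assuming the snippet is RoutingTransformerLM from the routing_transformer package; in that package the language model also returns an auxiliary k-means commitment loss alongside the logits.

import torch
from routing_transformer import RoutingTransformerLM  # assumed import path

model = RoutingTransformerLM(
    num_tokens = 20000,
    dim = 512,
    depth = 6,
    max_seq_len = 4096,
    heads = 8,
    window_size = 64,      # max_seq_len must be divisible by window_size
    causal = True,
)

tokens = torch.randint(0, 20000, (1, 4096))
logits, aux_loss = model(tokens)   # aux_loss is the routing/commitment loss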
Example #8
    def __init__(self, emb_sz, dim, max_seq_len=512, dropout=0., pos_enc='absolute',
                 axial_shape=None, axial_emb_dims=None):
        super().__init__()
        self.scale = dim**0.5
        self.emb = nn.Embedding(emb_sz, dim)
        if pos_enc == 'absolute':
            self.pos_enc = AbsolutePositionalEmbedding(dim, max_seq_len)
        elif pos_enc == 'fixed':
            self.pos_enc = FixedPositionalEmbedding(dim)
        elif pos_enc == 'axial':
            assert axial_shape is not None
            assert reduce(mul, axial_shape) == max_seq_len
            axial_emb_dims = default(axial_emb_dims, get_axial_dims(dim, len(axial_shape)))
            self.pos_enc = AxialPositionalEmbedding(dim, axial_shape, axial_emb_dims)
        self.dropout = nn.Dropout(dropout)
        self._init()
Example #9
    def __init__(
        self,
        *,
        dim,
        vae,
        num_text_tokens=10000,
        text_seq_len=256,
        depth,
        heads=8,
        dim_head=64,
        reversible=False,
        attn_dropout=0.,
        ff_dropout=0,
        sparse_attn=False,
        attn_types=None,
        loss_img_weight=7,
    ):
        super().__init__()
        assert isinstance(
            vae, (DiscreteVAE, OpenAIDiscreteVAE, VQGanVAE1024)
        ), 'vae must be an instance of DiscreteVAE, OpenAIDiscreteVAE, or VQGanVAE1024'

        image_size = vae.image_size
        num_image_tokens = vae.num_tokens
        image_fmap_size = (vae.image_size // (2**vae.num_layers))
        image_seq_len = image_fmap_size**2

        num_text_tokens = num_text_tokens + text_seq_len  # reserve unique padding tokens for each position (text seq len)

        self.text_emb = nn.Embedding(num_text_tokens, dim)
        self.image_emb = nn.Embedding(num_image_tokens, dim)

        self.text_pos_emb = nn.Embedding(text_seq_len + 1, dim)  # +1 for <bos>
        self.image_pos_emb = AxialPositionalEmbedding(
            dim, axial_shape=(image_fmap_size, image_fmap_size))

        self.num_text_tokens = num_text_tokens  # for offsetting logits index and calculating cross entropy loss
        self.num_image_tokens = num_image_tokens

        self.text_seq_len = text_seq_len
        self.image_seq_len = image_seq_len

        seq_len = text_seq_len + image_seq_len
        total_tokens = num_text_tokens + num_image_tokens
        self.total_tokens = total_tokens
        self.total_seq_len = seq_len

        self.vae = vae

        self.transformer = Transformer(dim=dim,
                                       causal=True,
                                       seq_len=seq_len,
                                       depth=depth,
                                       heads=heads,
                                       dim_head=dim_head,
                                       reversible=reversible,
                                       attn_dropout=attn_dropout,
                                       ff_dropout=ff_dropout,
                                       attn_types=attn_types,
                                       image_fmap_size=image_fmap_size,
                                       sparse_attn=sparse_attn)

        self.to_logits = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, self.total_tokens),
        )

        seq_range = torch.arange(seq_len)
        logits_range = torch.arange(total_tokens)

        seq_range = rearrange(seq_range, 'n -> () n ()')
        logits_range = rearrange(logits_range, 'd -> () () d')

        logits_mask = (((seq_range >= text_seq_len) &
                        (logits_range < num_text_tokens)) |
                       ((seq_range < text_seq_len) &
                        (logits_range >= num_text_tokens)))

        self.register_buffer('logits_mask', logits_mask, persistent=False)
        self.loss_img_weight = loss_img_weight
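A hedged usage sketch, assuming the snippet is DALLE from the dalle_pytorch package, paired with OpenAIDiscreteVAE (which also appears in example #5); the exact forward signature varies across package versions.

import torch
from dalle_pytorch import OpenAIDiscreteVAE, DALLE  # assumed import path

vae = OpenAIDiscreteVAE()          # pretrained discrete VAE that tokenizes 256x256 images

dalle = DALLE(
    dim = 512,
    vae = vae,
    num_text_tokens = 10000,
    text_seq_len = 256,
    depth = 2,
    heads = 8,
    dim_head = 64,
)

text = torch.randint(0, 10000, (1, 256))
images = torch.randn(1, 3, 256, 256)
loss = dalle(text, images, return_loss = True)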
Example #10
    def __init__(
        self,
        *,
        dim,
        vae,
        num_text_tokens = 10000,
        text_seq_len = 256,
        depth,
        heads = 8,
        dim_head = 64,
        reversible = False,
        attn_dropout = 0.,
        ff_dropout = 0,
        sparse_attn = False
    ):
        super().__init__()
        assert isinstance(vae, DiscreteVAE), 'vae must be an instance of DiscreteVAE'

        image_size = vae.image_size
        num_image_tokens = vae.num_tokens
        image_seq_len = (vae.image_size // (2 ** vae.num_layers)) ** 2

        self.text_emb = nn.Embedding(num_text_tokens, dim)
        self.image_emb = nn.Embedding(num_image_tokens, dim)

        self.text_pos_emb = nn.Embedding(text_seq_len, dim)
        self.image_pos_emb = AxialPositionalEmbedding(dim, axial_shape = (image_size, image_size))

        self.num_text_tokens = num_text_tokens # for offsetting logits index and calculating cross entropy loss
        self.num_image_tokens = num_image_tokens

        self.text_seq_len = text_seq_len
        self.image_seq_len = image_seq_len

        seq_len = text_seq_len + image_seq_len
        total_tokens = num_text_tokens + num_image_tokens + 1 # extra for EOS
        self.total_tokens = total_tokens
        
        self.vae = vae
        if exists(self.vae):
            self.vae = vae
            self.image_emb = vae.codebook

        self.transformer = Transformer(
            dim = dim,
            causal = True,
            seq_len = seq_len,
            depth = depth,
            heads = heads,
            dim_head = dim_head,
            reversible = reversible,
            attn_dropout = attn_dropout,
            ff_dropout = ff_dropout,
            sparse_attn = sparse_attn
        )

        self.to_logits = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, self.total_tokens),
        )

        seq_range = torch.arange(seq_len)
        logits_range = torch.arange(total_tokens)

        seq_range = rearrange(seq_range, 'n -> () n ()')
        logits_range = rearrange(logits_range, 'd -> () () d')

        logits_mask = (
            ((seq_range >= (text_seq_len - 1)) & (logits_range < num_text_tokens)) |
            ((seq_range < (text_seq_len - 1)) & (logits_range >= num_text_tokens)) |
            ((seq_range != (seq_len - 1)) & (logits_range >= (total_tokens - 1)))
        )

        self.register_buffer('logits_mask', logits_mask)
Example #11
    def __init__(self,
                 num_tokens,
                 dim,
                 depth,
                 max_seq_len,
                 heads=8,
                 dim_head=None,
                 causal=False,
                 emb_dim=None,
                 one_kv_head=False,
                 reversible=False,
                 ff_chunks=1,
                 ff_glu=False,
                 ff_dropout=0.,
                 attn_layer_dropout=0.,
                 attn_dropout=0.,
                 blindspot_size=1,
                 n_local_attn_heads=0,
                 local_attn_window_size=128,
                 return_embeddings=False,
                 receives_context=False,
                 pkm_layers=tuple(),
                 pkm_num_keys=128,
                 attend_axially=False,
                 linformer_settings=None,
                 context_linformer_settings=None):
        assert (
            max_seq_len % local_attn_window_size
        ) == 0, 'max sequence length must be divisible by the local attention window size'
        super().__init__()
        emb_dim = default(emb_dim, dim)
        self.max_seq_len = max_seq_len

        self.token_emb = nn.Embedding(num_tokens, emb_dim)
        self.axial_pos_emb = AxialPositionalEmbedding(
            emb_dim,
            axial_shape=(max_seq_len // local_attn_window_size,
                         local_attn_window_size))
        self.transformer = LinearAttentionTransformer(
            dim,
            depth,
            max_seq_len,
            heads=heads,
            dim_head=dim_head,
            causal=causal,
            one_kv_head=one_kv_head,
            ff_chunks=ff_chunks,
            ff_glu=ff_glu,
            ff_dropout=ff_dropout,
            attn_layer_dropout=attn_layer_dropout,
            attn_dropout=attn_dropout,
            reversible=reversible,
            blindspot_size=blindspot_size,
            n_local_attn_heads=n_local_attn_heads,
            local_attn_window_size=local_attn_window_size,
            receives_context=receives_context,
            pkm_layers=pkm_layers,
            pkm_num_keys=pkm_num_keys,
            attend_axially=attend_axially,
            linformer_settings=linformer_settings,
            context_linformer_settings=context_linformer_settings)

        if emb_dim != dim:
            self.transformer = ProjectInOut(self.transformer,
                                            emb_dim,
                                            dim,
                                            project_out=not return_embeddings)

        self.out = nn.Linear(
            emb_dim, num_tokens) if not return_embeddings else nn.Identity()
Example #12
    def __init__(self,
                 num_tokens,
                 dim,
                 depth,
                 max_seq_len,
                 heads=8,
                 dim_head=None,
                 window_size=64,
                 local_attn_window_size=None,
                 local_attn_radius_blocks=1,
                 causal=False,
                 emb_dim=None,
                 weight_tie=False,
                 attn_dropout=0.,
                 ff_dropout=0.,
                 attn_layer_dropout=0.,
                 layer_dropout=0.,
                 ff_mult=4,
                 ff_activation=None,
                 ff_glu=False,
                 return_embeddings=False,
                 n_local_attn_heads=0,
                 reversible=False,
                 ff_chunks=1,
                 kmeans_ema_decay=0.999,
                 commitment_factor=1e-4,
                 receives_context=False,
                 context_window_size=None,
                 rel_pos_emb=True,
                 _register_kmeans_update=True,
                 pkm_layers=tuple(),
                 pkm_num_keys=128,
                 moe_layers=tuple(),
                 moe_num_experts=4,
                 moe_loss_coef=1e-2,
                 num_mem_kv=0,
                 shared_qk=None,
                 context_shared_qk=False,
                 use_rezero=False,
                 use_scale_norm=False,
                 tie_embedding=False,
                 use_absolute_pos_emb=False,
                 return_context=False):
        super().__init__()
        assert (
            max_seq_len % window_size
        ) == 0, 'max sequence length must be divisible by the window size, to calculate the number of kmeans clusters'
        emb_dim = default(emb_dim, dim)
        self.emb_dim = emb_dim
        self.num_tokens = num_tokens
        self.dim = dim
        self.max_seq_len = max_seq_len

        self.token_emb = nn.Embedding(num_tokens, emb_dim)
        nn.init.normal_(self.token_emb.weight, std=0.02)

        self.pos_emb = AxialPositionalEmbedding(
            emb_dim, axial_shape=(max_seq_len // window_size, window_size)
        ) if not use_absolute_pos_emb else AbsolutePositionalEmbedding(
            emb_dim, max_seq_len)

        self.routing_transformer = RoutingTransformer(
            dim,
            depth,
            max_seq_len,
            heads=heads,
            dim_head=dim_head,
            window_size=window_size,
            local_attn_window_size=local_attn_window_size,
            local_attn_radius_blocks=local_attn_radius_blocks,
            causal=causal,
            weight_tie=weight_tie,
            ff_dropout=ff_dropout,
            attn_dropout=attn_dropout,
            attn_layer_dropout=attn_layer_dropout,
            layer_dropout=layer_dropout,
            n_local_attn_heads=n_local_attn_heads,
            ff_glu=ff_glu,
            reversible=reversible,
            ff_chunks=ff_chunks,
            kmeans_ema_decay=kmeans_ema_decay,
            receives_context=receives_context,
            context_window_size=context_window_size,
            rel_pos_emb=rel_pos_emb,
            pkm_layers=pkm_layers,
            pkm_num_keys=pkm_num_keys,
            moe_layers=moe_layers,
            moe_num_experts=moe_num_experts,
            moe_loss_coef=moe_loss_coef,
            num_mem_kv=num_mem_kv,
            shared_qk=shared_qk,
            context_shared_qk=context_shared_qk,
            _register_kmeans_update=_register_kmeans_update,
            use_rezero=use_rezero,
            use_scale_norm=use_scale_norm,
            ff_activation=ff_activation)

        if emb_dim != dim:
            self.routing_transformer = ProjectInOut(self.routing_transformer,
                                                    emb_dim,
                                                    dim,
                                                    project_out=False)

        self.norm = nn.LayerNorm(dim)

        if return_embeddings:
            self.out = nn.Linear(dim, emb_dim)
        elif return_context:
            self.out = nn.Identity()
        elif tie_embedding:
            self.out = MatrixMultiply(self.token_emb.weight, transpose=True)
        else:
            self.out = nn.Linear(dim, num_tokens)
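The tie_embedding branch above reuses the token embedding matrix (transposed) as the output projection. A minimal sketch of that weight-tying idea in plain PyTorch:

import torch
import torch.nn as nn

vocab, dim = 1000, 64
emb = nn.Embedding(vocab, dim)

hidden = torch.randn(2, 16, dim)       # (batch, seq, dim) output of the transformer
logits = hidden @ emb.weight.t()       # (batch, seq, vocab); no extra output parameters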