def __init__(self, num_tokens, dim, depth, max_seq_len, heads = 8, dim_head = 64, causal = False, emb_dim = None,
             reversible = False, ff_chunks = 1, ff_glu = False, ff_dropout = 0., attn_layer_dropout = 0., attn_dropout = 0.,
             blindspot_size = 1, n_local_attn_heads = 0, local_attn_window_size = 128, return_embeddings = False,
             receives_context = False, pkm_layers = tuple(), pkm_num_keys = 128, attend_axially = False,
             linformer_settings = None, context_linformer_settings = None, use_axial_pos_emb = True, use_rotary_emb = False):
    assert n_local_attn_heads == 0 or (max_seq_len % local_attn_window_size) == 0, 'max sequence length must be divisible by the local attention window size'
    super().__init__()
    emb_dim = default(emb_dim, dim)
    self.max_seq_len = max_seq_len

    self.token_emb = nn.Embedding(num_tokens, emb_dim)

    # rotary embeddings are applied per layer; axial / absolute embeddings are added to the token embeddings
    if use_rotary_emb:
        self.pos_emb = FixedPositionalEmbedding(emb_dim, max_seq_len)
        self.layer_pos_emb = FixedPositionalEmbedding(dim_head, max_seq_len)
    elif use_axial_pos_emb:
        self.pos_emb = AxialPositionalEmbedding(emb_dim, axial_shape = (math.ceil(max_seq_len / local_attn_window_size), local_attn_window_size))
        self.layer_pos_emb = always(None)
    else:
        self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len)
        self.layer_pos_emb = always(None)

    self.transformer = LinearAttentionTransformer(dim, depth, max_seq_len, heads = heads, dim_head = dim_head,
        causal = causal, ff_chunks = ff_chunks, ff_glu = ff_glu, ff_dropout = ff_dropout,
        attn_layer_dropout = attn_layer_dropout, attn_dropout = attn_dropout, reversible = reversible,
        blindspot_size = blindspot_size, n_local_attn_heads = n_local_attn_heads,
        local_attn_window_size = local_attn_window_size, receives_context = receives_context,
        pkm_layers = pkm_layers, pkm_num_keys = pkm_num_keys, attend_axially = attend_axially,
        linformer_settings = linformer_settings, context_linformer_settings = context_linformer_settings)

    if emb_dim != dim:
        self.transformer = ProjectInOut(self.transformer, emb_dim, dim, project_out = not return_embeddings)

    # the transformer output is projected back to emb_dim when emb_dim != dim, so normalize over emb_dim
    # (the flattened original had nn.LayerNorm(dim), which would break whenever emb_dim != dim)
    self.norm = nn.LayerNorm(emb_dim)
    self.out = nn.Linear(emb_dim, num_tokens) if not return_embeddings else nn.Identity()
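# Usage sketch (illustrative, not from the source): assumes the constructor above is
# linear-attention-transformer's LinearAttentionTransformerLM; the class name and import
# path are assumptions, the keyword arguments are taken from the signature above.
import torch
from linear_attention_transformer import LinearAttentionTransformerLM

lm = LinearAttentionTransformerLM(num_tokens = 20000, dim = 512, depth = 6, max_seq_len = 1024,
                                  causal = True, n_local_attn_heads = 4, local_attn_window_size = 128)
tokens = torch.randint(0, 20000, (1, 1024))  # max_seq_len (1024) is divisible by the window size (128)
logits = lm(tokens)                          # (1, 1024, 20000)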
def __init__(self, num_tokens, dim, depth, max_seq_len, heads = 8, dim_head = 64, bucket_size = 64, n_hashes = 4,
             ff_chunks = 100, attn_chunks = 1, causal = False, weight_tie = False, lsh_dropout = 0., ff_dropout = 0.,
             ff_mult = 4, ff_activation = None, ff_glu = False, post_attn_dropout = 0., layer_dropout = 0.,
             random_rotations_per_head = False, use_scale_norm = False, use_rezero = False, use_full_attn = False,
             full_attn_thres = 0, reverse_thres = 0, num_mem_kv = 0, one_value_head = False, emb_dim = None,
             return_embeddings = False, weight_tie_embedding = False, fixed_position_emb = False,
             absolute_position_emb = False, rotary_emb = True, axial_position_shape = None, n_local_attn_heads = 0,
             pkm_layers = tuple(), pkm_num_keys = 128):
    super().__init__()
    emb_dim = default(emb_dim, dim)
    self.max_seq_len = max_seq_len

    self.token_emb = nn.Embedding(num_tokens, emb_dim)
    self.to_model_dim = Identity() if emb_dim == dim else nn.Linear(emb_dim, dim)

    self.pos_emb = Always(0)
    self.layer_pos_emb = Always(None)

    if rotary_emb:
        self.layer_pos_emb = FixedPositionalEmbedding(dim_head)
    elif absolute_position_emb:
        self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len)
    elif fixed_position_emb:
        self.pos_emb = FixedPositionalEmbedding(emb_dim)
    else:
        axial_position_shape = default(axial_position_shape, (math.ceil(max_seq_len / bucket_size), bucket_size))
        self.pos_emb = AxialPositionalEmbedding(emb_dim, axial_position_shape)

    self.reformer = Reformer(dim, depth, max_seq_len, heads = heads, dim_head = dim_head, bucket_size = bucket_size,
        n_hashes = n_hashes, ff_chunks = ff_chunks, attn_chunks = attn_chunks, causal = causal, weight_tie = weight_tie,
        lsh_dropout = lsh_dropout, ff_mult = ff_mult, ff_activation = ff_activation, ff_glu = ff_glu,
        ff_dropout = ff_dropout,
        post_attn_dropout = post_attn_dropout,  # pass the argument through; the original hard-coded 0., silently ignoring the parameter
        layer_dropout = layer_dropout, random_rotations_per_head = random_rotations_per_head,
        use_scale_norm = use_scale_norm, use_rezero = use_rezero, use_full_attn = use_full_attn,
        full_attn_thres = full_attn_thres, reverse_thres = reverse_thres, num_mem_kv = num_mem_kv,
        one_value_head = one_value_head, n_local_attn_heads = n_local_attn_heads, pkm_layers = pkm_layers,
        pkm_num_keys = pkm_num_keys)
    self.norm = nn.LayerNorm(dim)

    if return_embeddings:
        self.out = Identity()
        return

    self.out = nn.Sequential(
        nn.Linear(dim, emb_dim) if emb_dim != dim else Identity(),
        nn.Linear(emb_dim, num_tokens) if not weight_tie_embedding else MatrixMultiply(self.token_emb.weight, transpose = True, normalize = True)
    )
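# Worked example (added for illustration): the default axial_position_shape above factors the
# sequence into (ceil(max_seq_len / bucket_size), bucket_size), so its product always covers max_seq_len.
import math

max_seq_len, bucket_size = 4096, 64
axial_position_shape = (math.ceil(max_seq_len / bucket_size), bucket_size)
assert axial_position_shape == (64, 64)
assert axial_position_shape[0] * axial_position_shape[1] >= max_seq_len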
def __init__(self, emb_sz: int, d_emb: int, max_seq_len: int = 512, dropout: float = 0.,
             pos_enc: str = 'absolute', axial_shape: Tuple = None, axial_emb_dims: Tuple = None):
    store_attr('d_emb')
    self.scale = d_emb ** 0.5
    self.std = 0.02  # fairseq: d_emb ** -0.5, fastai: 0.01
    self.emb = nn.Embedding(emb_sz, d_emb)
    self.dropout = nn.Dropout(dropout)

    if pos_enc == 'absolute':
        self.pos_enc = AbsolutePositionalEmbedding(d_emb, max_seq_len)
    elif pos_enc == 'fixed':
        self.pos_enc = FixedPositionalEmbedding(d_emb)
    elif pos_enc == 'axial':
        assert axial_shape is not None
        assert reduce(mul, axial_shape) == max_seq_len
        axial_emb_dims = default(axial_emb_dims, get_axial_dims(d_emb, len(axial_shape)))
        self.pos_enc = AxialPositionalEmbedding(d_emb, axial_shape, axial_emb_dims)

    self._init()
def __init__(self, *, num_tokens, max_seq_len, dim, depth, heads, dim_head = 64, local_attn_heads = 0,
             local_window_size = 256, causal = False, ff_mult = 4, nb_features = None, feature_redraw_interval = 1000,
             reversible = False, ff_chunks = 1, ff_glu = False, emb_dropout = 0., ff_dropout = 0., attn_dropout = 0.,
             generalized_attention = False, kernel_fn = nn.ReLU(), use_scalenorm = False, use_rezero = False,
             cross_attend = False, no_projection = False, tie_embed = False, rotary_position_emb = True,
             axial_position_emb = False, axial_position_shape = None, auto_check_redraw = True, qkv_bias = False,
             attn_out_bias = False, shift_tokens = False):
    super().__init__()
    local_attn_heads = cast_tuple(local_attn_heads)

    self.max_seq_len = max_seq_len
    self.token_emb = nn.Embedding(num_tokens, dim)

    if rotary_position_emb:
        self.pos_emb = FixedPositionalEmbedding(dim, max_seq_len)
        self.layer_pos_emb = FixedPositionalEmbedding(dim_head, max_seq_len)
    elif axial_position_emb:
        axial_position_shape = default(axial_position_shape, (math.ceil(max_seq_len / 64), 64))
        self.pos_emb = AxialPositionalEmbedding(dim, axial_position_shape)
        self.layer_pos_emb = Always(None)
    else:
        self.pos_emb = AbsolutePositionalEmbedding(dim, max_seq_len)
        self.layer_pos_emb = Always(None)

    self.dropout = nn.Dropout(emb_dropout)

    self.performer = Performer(dim, depth, heads, dim_head, local_attn_heads, local_window_size, causal, ff_mult,
        nb_features, feature_redraw_interval, reversible, ff_chunks, generalized_attention, kernel_fn, use_scalenorm,
        use_rezero, ff_glu, ff_dropout, attn_dropout, cross_attend, no_projection, auto_check_redraw, qkv_bias,
        attn_out_bias, shift_tokens)
    self.norm = nn.LayerNorm(dim)
    self.to_out = nn.Linear(dim, num_tokens) if not tie_embed else None
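# Usage sketch (illustrative, not from the source): assumes the constructor above is
# performer-pytorch's PerformerLM; the import path is an assumption, the keyword
# arguments come from the signature above.
import torch
from performer_pytorch import PerformerLM

model = PerformerLM(num_tokens = 20000, max_seq_len = 2048, dim = 512, depth = 6, heads = 8, causal = True)
x = torch.randint(0, 20000, (1, 2048))
logits = model(x)  # (1, 2048, 20000); rotary position embeddings are applied per layer by default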
def build(self):
    self.image_feature_module = build_image_encoder(
        self.config.image_feature_processor, direct_features=True
    )
    if self.config.concate_trace:
        self.trace_feature_module = build_encoder(self.config.trace_feature_encoder)

    if self.config.base_model_name == "bert-base-uncased":
        self.encoderdecoder = EncoderDecoderModel.from_encoder_decoder_pretrained(
            "bert-base-uncased", "bert-base-uncased"
        )
    elif self.config.base_model_name == "2layer-base":
        config_encoder = BertConfig()
        config_decoder = BertConfig()
        config_encoder.max_position_embeddings = 1090
        config_encoder.num_hidden_layers = 2
        config_decoder.num_hidden_layers = 2
        self.codec_config = EncoderDecoderConfig.from_encoder_decoder_configs(
            config_encoder, config_decoder
        )
        self.encoderdecoder = EncoderDecoderModel(config=self.codec_config)
    elif self.config.base_model_name == "3layer-base":
        config_encoder = BertConfig()
        config_decoder = BertConfig()
        config_encoder.num_hidden_layers = 3
        config_decoder.num_hidden_layers = 3
        self.codec_config = EncoderDecoderConfig.from_encoder_decoder_configs(
            config_encoder, config_decoder
        )
        self.encoderdecoder = EncoderDecoderModel(config=self.codec_config)

    if self.config.loop_contrastive:
        self.trace_caption_contrastive = TraceCaptionContrastiveModel(
            self.config.tc_contrastive_aggregate_method
        )
    if hasattr(self.config, "pretrans_attention") and self.config.pretrans_attention:
        tempconf = self.encoderdecoder.config.encoder
        num_heads = tempconf.num_attention_heads
        num_layers = tempconf.num_hidden_layers
        self.attention_trans = AttentionTransform(num_layers, num_heads, 100)

    self.BOS_ID = 101  # [CLS] id in the BERT uncased vocab, used as beginning-of-sequence

    self.vae = OpenAIDiscreteVAE()
    image_code_dim = 768
    image_fmap_size = self.vae.image_size // (2 ** self.vae.num_layers)
    self.image_seq_len = image_fmap_size ** 2
    self.image_emb = torch.nn.Embedding(self.vae.num_tokens, image_code_dim)
    self.image_pos_emb = AxialPositionalEmbedding(
        image_code_dim, axial_shape=(image_fmap_size, image_fmap_size)
    )
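# Worked example (added for illustration): the image-token geometry above, with OpenAI's
# discrete VAE defaults (image_size = 256, num_layers = 3, i.e. an 8x spatial downsample).
image_size, num_layers = 256, 3
image_fmap_size = image_size // (2 ** num_layers)  # 32
image_seq_len = image_fmap_size ** 2               # 1024 image tokens, laid out on a (32, 32) axial grid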
def __init__(self, num_tokens, dim, max_seq_len, depth, heads = 8, bucket_size = 64, kv_bucket_size = None,
             context_bucket_size = None, causal = False, non_permutative = True, sinkhorn_iter = 5, n_sortcut = 0,
             temperature = 0.75, reversible = False, ff_chunks = 1, ff_glu = False, return_embeddings = False,
             ff_dropout = 0., attn_dropout = 0., attn_layer_dropout = 0., layer_dropout = 0., weight_tie = False,
             emb_dim = None, use_simple_sort_net = None, receives_context = False, context_n_sortcut = 0,
             n_local_attn_heads = 0, use_rezero = False, n_top_buckets = 2, pkm_layers = tuple(), pkm_num_keys = 128):
    super().__init__()
    emb_dim = default(emb_dim, dim)
    self.max_seq_len = max_seq_len

    self.to_token_emb = nn.Embedding(num_tokens, emb_dim)
    self.pos_emb = nn.Embedding(max_seq_len, emb_dim)
    # assumes max_seq_len is a multiple of bucket_size, so the axial shape covers the full sequence
    self.axial_pos_emb = AxialPositionalEmbedding(emb_dim, axial_shape = (max_seq_len // bucket_size, bucket_size))

    self.sinkhorn_transformer = SinkhornTransformer(dim, depth, max_seq_len = max_seq_len, causal = causal, heads = heads,
        bucket_size = bucket_size, kv_bucket_size = kv_bucket_size, context_bucket_size = context_bucket_size,
        non_permutative = non_permutative, sinkhorn_iter = sinkhorn_iter, n_sortcut = n_sortcut, temperature = temperature,
        reversible = reversible, ff_chunks = ff_chunks, ff_dropout = ff_dropout, attn_dropout = attn_dropout,
        attn_layer_dropout = attn_layer_dropout, layer_dropout = layer_dropout, weight_tie = weight_tie, ff_glu = ff_glu,
        use_simple_sort_net = use_simple_sort_net, receives_context = receives_context, context_n_sortcut = context_n_sortcut,
        n_local_attn_heads = n_local_attn_heads, use_rezero = use_rezero, n_top_buckets = n_top_buckets,
        pkm_layers = pkm_layers, pkm_num_keys = pkm_num_keys)

    if emb_dim != dim:
        self.sinkhorn_transformer = ProjectInOut(self.sinkhorn_transformer, emb_dim, dim, project_out = (not return_embeddings))

    self.to_logits = identity if return_embeddings else nn.Linear(emb_dim, num_tokens)
def __init__(self, num_tokens, dim, depth, max_seq_len, heads = 8, dim_head = None, window_size = 64,
             local_attn_window_size = None, causal = False, emb_dim = None, weight_tie = False, attn_dropout = 0.,
             ff_dropout = 0., attn_layer_dropout = 0., layer_dropout = 0., ff_mult = 4, ff_activation = None,
             ff_glu = False, return_embeddings = False, n_local_attn_heads = 0, reversible = False, ff_chunks = 1,
             kmeans_ema_decay = 0.999, commitment_factor = 1e-4, receives_context = False, context_window_size = None,
             rel_pos_emb = True, _register_kmeans_update = True, pkm_layers = tuple(), pkm_num_keys = 128,
             moe_layers = tuple(), moe_num_experts = 4, moe_loss_coef = 1e-2, num_mem_kv = 0, shared_qk = None,
             context_shared_qk = False):
    super().__init__()
    assert (max_seq_len % window_size) == 0, 'max sequence length must be divisible by the window size, to calculate number of kmeans cluster'

    emb_dim = default(emb_dim, dim)
    self.max_seq_len = max_seq_len

    self.token_emb = nn.Embedding(num_tokens, emb_dim)
    self.axial_pos_emb = AxialPositionalEmbedding(emb_dim, axial_shape = (max_seq_len // window_size, window_size))

    self.routing_transformer = RoutingTransformer(dim, depth, max_seq_len, heads = heads, dim_head = dim_head,
        window_size = window_size, local_attn_window_size = local_attn_window_size, causal = causal,
        weight_tie = weight_tie, ff_dropout = ff_dropout, attn_dropout = attn_dropout,
        attn_layer_dropout = attn_layer_dropout, layer_dropout = layer_dropout, n_local_attn_heads = n_local_attn_heads,
        ff_glu = ff_glu, reversible = reversible, ff_chunks = ff_chunks, kmeans_ema_decay = kmeans_ema_decay,
        receives_context = receives_context, context_window_size = context_window_size, rel_pos_emb = rel_pos_emb,
        pkm_layers = pkm_layers, pkm_num_keys = pkm_num_keys, moe_layers = moe_layers, moe_num_experts = moe_num_experts,
        moe_loss_coef = moe_loss_coef, num_mem_kv = num_mem_kv, shared_qk = shared_qk,
        context_shared_qk = context_shared_qk, _register_kmeans_update = _register_kmeans_update)

    if emb_dim != dim:
        self.routing_transformer = ProjectInOut(self.routing_transformer, emb_dim, dim, project_out = not return_embeddings)

    self.out = nn.Linear(emb_dim, num_tokens) if not return_embeddings else identity
def __init__(self, emb_sz, dim, max_seq_len=512, dropout=0., pos_enc='absolute', axial_shape=None, axial_emb_dims=None):
    super().__init__()
    self.scale = dim ** 0.5
    self.emb = nn.Embedding(emb_sz, dim)

    if pos_enc == 'absolute':
        self.pos_enc = AbsolutePositionalEmbedding(dim, max_seq_len)
    elif pos_enc == 'fixed':
        self.pos_enc = FixedPositionalEmbedding(dim)
    elif pos_enc == 'axial':
        assert axial_shape is not None
        assert reduce(mul, axial_shape) == max_seq_len
        axial_emb_dims = default(axial_emb_dims, get_axial_dims(dim, len(axial_shape)))
        self.pos_enc = AxialPositionalEmbedding(dim, axial_shape, axial_emb_dims)

    self.dropout = nn.Dropout(dropout)
    self._init()
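# Minimal sketch (added for illustration) of the 'axial' branch above, assuming the
# AxialPositionalEmbedding used here is lucidrains' axial-positional-embedding package.
import torch
from axial_positional_embedding import AxialPositionalEmbedding

pos_enc = AxialPositionalEmbedding(dim = 512, axial_shape = (64, 8))  # 64 * 8 == max_seq_len == 512
x = torch.randn(1, 512, 512)  # (batch, seq_len, dim)
x = x + pos_enc(x)            # the axial positional embedding is added to the token embeddings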
def __init__(self, *, dim, vae, num_text_tokens=10000, text_seq_len=256, depth, heads=8, dim_head=64,
             reversible=False, attn_dropout=0., ff_dropout=0., sparse_attn=False, attn_types=None,
             loss_img_weight=7):
    super().__init__()
    assert isinstance(vae, (DiscreteVAE, OpenAIDiscreteVAE, VQGanVAE1024)), 'vae must be an instance of DiscreteVAE'

    image_size = vae.image_size
    num_image_tokens = vae.num_tokens
    image_fmap_size = (vae.image_size // (2 ** vae.num_layers))
    image_seq_len = image_fmap_size ** 2

    num_text_tokens = num_text_tokens + text_seq_len  # reserve unique padding tokens for each position (text seq len)

    self.text_emb = nn.Embedding(num_text_tokens, dim)
    self.image_emb = nn.Embedding(num_image_tokens, dim)
    self.text_pos_emb = nn.Embedding(text_seq_len + 1, dim)  # +1 for <bos>
    self.image_pos_emb = AxialPositionalEmbedding(dim, axial_shape=(image_fmap_size, image_fmap_size))

    self.num_text_tokens = num_text_tokens  # for offsetting logits index and calculating cross entropy loss
    self.num_image_tokens = num_image_tokens
    self.text_seq_len = text_seq_len
    self.image_seq_len = image_seq_len

    seq_len = text_seq_len + image_seq_len
    total_tokens = num_text_tokens + num_image_tokens
    self.total_tokens = total_tokens
    self.total_seq_len = seq_len

    self.vae = vae

    self.transformer = Transformer(dim=dim, causal=True, seq_len=seq_len, depth=depth, heads=heads,
                                   dim_head=dim_head, reversible=reversible, attn_dropout=attn_dropout,
                                   ff_dropout=ff_dropout, attn_types=attn_types, image_fmap_size=image_fmap_size,
                                   sparse_attn=sparse_attn)

    self.to_logits = nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, self.total_tokens),
    )

    seq_range = torch.arange(seq_len)
    logits_range = torch.arange(total_tokens)

    seq_range = rearrange(seq_range, 'n -> () n ()')
    logits_range = rearrange(logits_range, 'd -> () () d')

    logits_mask = (((seq_range >= text_seq_len) & (logits_range < num_text_tokens)) |
                   ((seq_range < text_seq_len) & (logits_range >= num_text_tokens)))

    self.register_buffer('logits_mask', logits_mask, persistent=False)
    self.loss_img_weight = loss_img_weight
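# Worked example (added for illustration): the logits_mask above on toy sizes, showing which
# (position, token) pairs are masked out. With text_seq_len = 2, image_seq_len = 2, 3 text tokens
# and 2 image tokens, positions 0-1 may only produce text tokens (ids 0-2) and positions 2-3
# only image tokens (ids 3-4).
import torch
from einops import rearrange

text_seq_len, num_text_tokens = 2, 3
seq_len, total_tokens = 4, 5
seq_range = rearrange(torch.arange(seq_len), 'n -> () n ()')
logits_range = rearrange(torch.arange(total_tokens), 'd -> () () d')
logits_mask = (((seq_range >= text_seq_len) & (logits_range < num_text_tokens)) |
               ((seq_range < text_seq_len) & (logits_range >= num_text_tokens)))
# True marks a disallowed pair; row i lists the masked token ids for sequence position i
print(logits_mask[0].int())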
def __init__(self, *, dim, vae, num_text_tokens = 10000, text_seq_len = 256, depth, heads = 8, dim_head = 64,
             reversible = False, attn_dropout = 0., ff_dropout = 0., sparse_attn = False):
    super().__init__()
    assert isinstance(vae, DiscreteVAE), 'vae must be an instance of DiscreteVAE'

    image_size = vae.image_size
    num_image_tokens = vae.num_tokens
    image_fmap_size = image_size // (2 ** vae.num_layers)
    image_seq_len = image_fmap_size ** 2

    self.text_emb = nn.Embedding(num_text_tokens, dim)
    self.text_pos_emb = nn.Embedding(text_seq_len, dim)
    # the axial shape must cover image_seq_len positions, so use the feature-map side length;
    # the flattened original used (image_size, image_size), which over-allocates by the downsample factor squared
    self.image_pos_emb = AxialPositionalEmbedding(dim, axial_shape = (image_fmap_size, image_fmap_size))

    self.num_text_tokens = num_text_tokens # for offsetting logits index and calculating cross entropy loss
    self.num_image_tokens = num_image_tokens
    self.text_seq_len = text_seq_len
    self.image_seq_len = image_seq_len

    seq_len = text_seq_len + image_seq_len
    total_tokens = num_text_tokens + num_image_tokens + 1 # extra for EOS
    self.total_tokens = total_tokens

    self.vae = vae
    self.image_emb = vae.codebook  # share the VAE codebook as the image token embedding

    self.transformer = Transformer(
        dim = dim,
        causal = True,
        seq_len = seq_len,
        depth = depth,
        heads = heads,
        dim_head = dim_head,
        reversible = reversible,
        attn_dropout = attn_dropout,
        ff_dropout = ff_dropout,
        sparse_attn = sparse_attn
    )

    self.to_logits = nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, self.total_tokens),
    )

    seq_range = torch.arange(seq_len)
    logits_range = torch.arange(total_tokens)

    seq_range = rearrange(seq_range, 'n -> () n ()')
    logits_range = rearrange(logits_range, 'd -> () () d')

    logits_mask = (
        ((seq_range >= (text_seq_len - 1)) & (logits_range < num_text_tokens)) |
        ((seq_range < (text_seq_len - 1)) & (logits_range >= num_text_tokens)) |
        ((seq_range != (seq_len - 1)) & (logits_range >= (total_tokens - 1)))
    )

    self.register_buffer('logits_mask', logits_mask)
def __init__(self, num_tokens, dim, depth, max_seq_len, heads=8, dim_head=None, causal=False, emb_dim=None,
             one_kv_head=False, reversible=False, ff_chunks=1, ff_glu=False, ff_dropout=0., attn_layer_dropout=0.,
             attn_dropout=0., blindspot_size=1, n_local_attn_heads=0, local_attn_window_size=128,
             return_embeddings=False, receives_context=False, pkm_layers=tuple(), pkm_num_keys=128,
             attend_axially=False, linformer_settings=None, context_linformer_settings=None):
    # the flattened original reused the routing-transformer kmeans message here; this model has no kmeans clusters
    assert (max_seq_len % local_attn_window_size) == 0, 'max sequence length must be divisible by the local attention window size'
    super().__init__()
    emb_dim = default(emb_dim, dim)
    self.max_seq_len = max_seq_len

    self.token_emb = nn.Embedding(num_tokens, emb_dim)
    self.axial_pos_emb = AxialPositionalEmbedding(
        emb_dim, axial_shape=(max_seq_len // local_attn_window_size, local_attn_window_size))

    self.transformer = LinearAttentionTransformer(
        dim, depth, max_seq_len, heads=heads, dim_head=dim_head, causal=causal, one_kv_head=one_kv_head,
        ff_chunks=ff_chunks, ff_glu=ff_glu, ff_dropout=ff_dropout, attn_layer_dropout=attn_layer_dropout,
        attn_dropout=attn_dropout, reversible=reversible, blindspot_size=blindspot_size,
        n_local_attn_heads=n_local_attn_heads, local_attn_window_size=local_attn_window_size,
        receives_context=receives_context, pkm_layers=pkm_layers, pkm_num_keys=pkm_num_keys,
        attend_axially=attend_axially, linformer_settings=linformer_settings,
        context_linformer_settings=context_linformer_settings)

    if emb_dim != dim:
        self.transformer = ProjectInOut(self.transformer, emb_dim, dim, project_out=not return_embeddings)

    self.out = nn.Linear(emb_dim, num_tokens) if not return_embeddings else nn.Identity()
def __init__(self, num_tokens, dim, depth, max_seq_len, heads=8, dim_head=None, window_size=64,
             local_attn_window_size=None, local_attn_radius_blocks=1, causal=False, emb_dim=None, weight_tie=False,
             attn_dropout=0., ff_dropout=0., attn_layer_dropout=0., layer_dropout=0., ff_mult=4, ff_activation=None,
             ff_glu=False, return_embeddings=False, n_local_attn_heads=0, reversible=False, ff_chunks=1,
             kmeans_ema_decay=0.999, commitment_factor=1e-4, receives_context=False, context_window_size=None,
             rel_pos_emb=True, _register_kmeans_update=True, pkm_layers=tuple(), pkm_num_keys=128,
             moe_layers=tuple(), moe_num_experts=4, moe_loss_coef=1e-2, num_mem_kv=0, shared_qk=None,
             context_shared_qk=False, use_rezero=False, use_scale_norm=False, tie_embedding=False,
             use_absolute_pos_emb=False, return_context=False):
    super().__init__()
    assert (max_seq_len % window_size) == 0, 'max sequence length must be divisible by the window size, to calculate number of kmeans cluster'

    emb_dim = default(emb_dim, dim)
    self.emb_dim = emb_dim
    self.num_tokens = num_tokens
    self.dim = dim
    self.max_seq_len = max_seq_len

    self.token_emb = nn.Embedding(num_tokens, emb_dim)
    nn.init.normal_(self.token_emb.weight, std=0.02)

    self.pos_emb = AxialPositionalEmbedding(
        emb_dim, axial_shape=(max_seq_len // window_size, window_size)
    ) if not use_absolute_pos_emb else AbsolutePositionalEmbedding(emb_dim, max_seq_len)

    self.routing_transformer = RoutingTransformer(
        dim, depth, max_seq_len, heads=heads, dim_head=dim_head, window_size=window_size,
        local_attn_window_size=local_attn_window_size, local_attn_radius_blocks=local_attn_radius_blocks,
        causal=causal, weight_tie=weight_tie, ff_dropout=ff_dropout, attn_dropout=attn_dropout,
        attn_layer_dropout=attn_layer_dropout, layer_dropout=layer_dropout, n_local_attn_heads=n_local_attn_heads,
        ff_glu=ff_glu, reversible=reversible, ff_chunks=ff_chunks, kmeans_ema_decay=kmeans_ema_decay,
        receives_context=receives_context, context_window_size=context_window_size, rel_pos_emb=rel_pos_emb,
        pkm_layers=pkm_layers, pkm_num_keys=pkm_num_keys, moe_layers=moe_layers, moe_num_experts=moe_num_experts,
        moe_loss_coef=moe_loss_coef, num_mem_kv=num_mem_kv, shared_qk=shared_qk, context_shared_qk=context_shared_qk,
        _register_kmeans_update=_register_kmeans_update, use_rezero=use_rezero, use_scale_norm=use_scale_norm,
        ff_activation=ff_activation)

    if emb_dim != dim:
        self.routing_transformer = ProjectInOut(self.routing_transformer, emb_dim, dim, project_out=False)

    self.norm = nn.LayerNorm(dim)

    if return_embeddings:
        self.out = nn.Linear(dim, emb_dim)
    elif return_context:
        self.out = nn.Identity()
    elif tie_embedding:
        self.out = MatrixMultiply(self.token_emb.weight, transpose=True)
    else:
        self.out = nn.Linear(dim, num_tokens)
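# Sketch (added for illustration) of the weight-tied output head used when tie_embedding is set.
# MatrixMultiply is this repo's helper; the re-implementation below is an assumption about its
# behavior: it multiplies by the (optionally transposed) embedding matrix so that input and
# output embeddings share weights.
import torch
import torch.nn as nn

class MatrixMultiply(nn.Module):
    def __init__(self, tensor, transpose=False):
        super().__init__()
        self.tensor = tensor
        self.transpose = transpose

    def forward(self, x):
        tensor = self.tensor.t() if self.transpose else self.tensor
        return x @ tensor

token_emb = nn.Embedding(100, 16)
out = MatrixMultiply(token_emb.weight, transpose=True)
logits = out(torch.randn(2, 8, 16))  # (2, 8, 100), logits tied to token_emb.weight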