def xmoe2_v1():
  """Model incorporating mixture-of-experts and local-attention.

  ~6B parameters

  32 experts in 3 hierarchical moe layers.

  Returns:
    a hparams
  """
  hparams = xmoe2_dense(0)
  moe.set_default_moe_hparams(hparams)
  hparams.decoder_layers = (
      ["local_att", "local_att", "drd", "att",
       "drd", "local_att", "local_att", "hmoe"] * 4)[:-1]
  hparams.d_ff = 2048
  hparams.d_kv = 128
  hparams.moe_hidden_size = 32768
  hparams.mesh_shape = "b0:4;b1:8"
  hparams.layout = "outer_batch:b0;inner_batch:b1,expert_x:b1,expert_y:b0"
  hparams.outer_batch_size = 4
  hparams.moe_num_experts = [8, 4]
  hparams.num_heads = 4
  return hparams

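# Illustrative sanity check (not in the original file) of the decoder layer
# pattern used by xmoe2_v1: the 8-layer block tiled 4 times with the final
# "hmoe" dropped yields 31 layers, 3 of them hierarchical moe layers; with
# moe_num_experts = [8, 4], each hmoe layer holds 8 * 4 = 32 experts, matching
# the "32 experts in 3 hierarchical moe layers" in the docstring.
def _check_xmoe2_v1_layers():
  layers = (["local_att", "local_att", "drd", "att",
             "drd", "local_att", "local_att", "hmoe"] * 4)[:-1]
  assert len(layers) == 31
  assert layers.count("hmoe") == 3
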
def xmoe_top_2():
  """Mixture of experts (16 experts)."""
  hparams = xmoe_dense_4k()
  moe.set_default_moe_hparams(hparams)
  hparams.mesh_shape = "all:8"
  hparams.layout = "batch:all;experts:all"
  return hparams

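# A minimal sketch (an assumption about the mesh-tensorflow string
# conventions used above, not code from the original file): mesh_shape is a
# ";"-separated list of name:size pairs, and layout maps tensor-dimension
# names to mesh-dimension names. So "all:8" declares a single mesh dimension
# of size 8, and "batch:all;experts:all" splits both the batch and the
# experts tensor dimensions across it.
def _parse_mesh_shape(mesh_shape="all:8"):
  # "all:8" -> [("all", 8)]; "b0:4;b1:8" -> [("b0", 4), ("b1", 8)]
  pairs = [pair.split(":") for pair in mesh_shape.split(";")]
  return [(name, int(size)) for name, size in pairs]
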
def xmoe_wiki_x32():
  """Two-dimensional hierarchical mixture of experts.

  (8x4 experts) * (16M params/expert) * 6 layers = 3B params

  Returns:
    a hparams object.
  """
  hparams = xmoe_wiki_base()
  moe.set_default_moe_hparams(hparams)
  hparams.feedforward_layer = "hmoe"
  hparams.moe_hidden_size = 8192
  hparams.mesh_shape = "b0:4;b1:8"
  hparams.layout = "outer_batch:b0;inner_batch:b1,expert_x:b1,expert_y:b0"
  hparams.outer_batch_size = 4
  hparams.moe_num_experts = [8, 4]
  return hparams

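# Worked check of the parameter estimate in the docstring above
# (illustrative; assumes d_model = 1024, as in wiki_2x2_base below, and a
# two-matrix feedforward per expert):
#   params/expert = 2 * d_model * moe_hidden_size = 2 * 1024 * 8192 ~= 16.8M
#   (8 * 4) experts * 16.8M * 6 layers ~= 3.2B, i.e. the ~3B quoted.
def _xmoe_wiki_x32_param_estimate(d_model=1024, moe_hidden_size=8192):
  params_per_expert = 2 * d_model * moe_hidden_size  # ~16.8M
  return 8 * 4 * params_per_expert * 6  # ~3.2e9
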
def mtf_transformer_lm_moe():
  """Mixture of experts language model.

  Compare to mtf_transformer.mtf_transformer_lm_baseline()

  Run this on 2x2 on languagemodel_lm1b32k_packed for 272000 steps (10 epochs)
  900M params.

  Results on LM1B:
      params/10^9  log-ppl(per-token)
      0.90         TODO(noam): rerun experiment

  Returns:
    a hparams
  """
  hparams = mtf_transformer.mtf_transformer_lm_baseline()
  moe.set_default_moe_hparams(hparams)
  hparams.mesh_shape = "all:8"
  hparams.layout = "batch:all;experts:all"
  hparams.feedforward_layer = "moe"
  return hparams

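# Hypothetical invocation mirroring the run instructions in the docstring
# above, assuming this hparams set is registered with the tensor2tensor
# registry (the flag values come from the docstring; the overall command
# shape is an assumption):
#
#   t2t-trainer --model=mtf_transformer \
#     --hparams_set=mtf_transformer_lm_moe \
#     --problem=languagemodel_lm1b32k_packed \
#     --train_steps=272000
#
# 272000 steps over 10 epochs implies ~27200 steps per epoch at this
# batch size.
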
def xmoe_wiki_x():
  """Baseline set of parameters for mixture-of-experts.

  ~6B parameters

  Returns:
    a hparams
  """
  hparams = xmoe_wiki_base(0)
  moe.set_default_moe_hparams(hparams)
  hparams.decoder_layers = (
      ["att", "drd", "att", "drd", "att", "hmoe"] * 3
      + ["att", "drd", "att", "drd"])
  hparams.d_ff = 2048
  hparams.d_kv = 128
  hparams.moe_hidden_size = 32768
  hparams.mesh_shape = "b0:4;b1:8"
  hparams.layout = "outer_batch:b0;inner_batch:b1,expert_x:b1,expert_y:b0"
  hparams.outer_batch_size = 4
  hparams.moe_num_experts = [8, 4]
  hparams.num_heads = 4
  return hparams

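# Illustrative check of the layer pattern above (not in the original file):
# 3 tiled 6-layer blocks plus a 4-layer tail gives 22 decoder layers, 3 of
# them hmoe. With an assumed d_model = 1024 (xmoe_wiki_base is not shown
# here), the experts alone contribute roughly
# 3 layers * 32 experts * 2 * 1024 * 32768 ~= 6.4B parameters, the same
# order as the ~6B in the docstring.
def _check_xmoe_wiki_x_layers():
  layers = (["att", "drd", "att", "drd", "att", "hmoe"] * 3
            + ["att", "drd", "att", "drd"])
  assert len(layers) == 22
  assert layers.count("hmoe") == 3
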
def wiki_2x2_base():
  """Set of architectural experiments - language model on wikipedia on a 2x2.

  1 epoch = ~180k steps at batch size 32 - we may never finish an epoch!

  Returns:
    a hparams
  """
  hparams = mtf_transformer.mtf_transformer_base_lm()
  hparams.shared_embedding_and_softmax_weights = False
  # no dropout - dataset is big enough to avoid overfitting.
  hparams.attention_dropout = 0.0
  hparams.relu_dropout = 0.0
  hparams.layer_prepostprocess_dropout = 0.0
  hparams.max_length = 1024
  # 4 sequences per core
  hparams.batch_size = 32
  # We don't use linear decay in these experiments, since we don't want
  # a sharp jump in quality at the end of the training schedule.
  # You can insert this once you find the right architecture.
  hparams.learning_rate_schedule = "rsqrt_decay"
  hparams.mesh_shape = "all:8"
  hparams.layout = "batch:all;experts:all"
  # parameters for mixture-of-experts
  moe.set_default_moe_hparams(hparams)
  hparams.moe_num_experts = 16
  hparams.moe_hidden_size = 8192
  hparams.decoder_layers = ["att", "drd"] * 6
  hparams.d_model = 1024
  hparams.d_ff = 2048
  hparams.d_kv = 128
  hparams.num_heads = 4
  return hparams

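# Worked arithmetic for the comments above (illustrative; assumes a 2x2 TPU
# slice exposes 8 cores, matching mesh_shape "all:8"): batch_size / 8 = 4
# sequences per core, as the inline comment says, and 1 epoch at ~180k steps
# implies roughly 180000 * 32 * 1024 ~= 5.9B training tokens.
def _wiki_2x2_sequences_per_core(batch_size=32, num_cores=8):
  return batch_size // num_cores  # 4
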