def xmoe_dense_4k():
    """Dense (d_ff=4096) entry in a series of cheap language-model experiments.

    For all of these architectures, we run on languagemodel_lm1b8k_packed
    for 32000 steps.

    All log-perplexities are per-token - multiply by 1.298 for per-word.

    Results:
      model             params(M)  einsum  alltoall  mxu-util  log-ppl
      xmoe_dense_4k     30         3.0e12  0         45%       3.31
      xmoe_dense_8k     46         4.7e12  0         49%       3.24
      xmoe_dense_64k    282        2.8e13  0         -         3.06
      xmoe_top_2        282        4.0e12  3.4e8     36%       3.07
      xmoe_top_2_c15    282        4.5e12  4.0e8     38%       3.07
      xmoe_2d           282        5.3e12  7.6e8     34%       3.06

    Trained at 4x the batch size:
      xmoe_2d_88        1090       2.1e13  3.0e9     24%       3.07

    Note: configurations and code are likely to change without notice.

    Returns:
      a hparams
    """
    hparams = mtf_transformer.mtf_transformer_base_lm()

    # Ordered (name, value) pairs applied on top of the base LM hparams.
    # A tuple (rather than a dict) keeps the assignment order explicit.
    overrides = (
        # All dropout is disabled.
        ("attention_dropout", 0.0),
        ("relu_dropout", 0.0),
        ("layer_prepostprocess_dropout", 0.0),
        # These settings are held constant across the whole experiment series.
        ("batch_size", 128),
        ("d_model", 512),
        ("d_kv", 128),
        ("num_heads", 4),
        ("decoder_layers", ["att", "drd"] * 4),
        ("shared_embedding_and_softmax_weights", False),
        ("learning_rate_schedule", "rsqrt_decay"),
        # These ffn/moe-related settings are what the sibling experiments vary.
        ("d_ff", 4096),
        ("layout", "batch:batch;vocab:model;d_ff:model;heads:model"),
        ("mesh_shape", "batch:8"),
    )
    for name, value in overrides:
        setattr(hparams, name, value)
    return hparams
def wiki_2x2_base():
    """Base hparams for architectural experiments: Wikipedia LM on a 2x2.

    1 epoch = ~180k steps at batch size 32 - we may never finish an epoch!

    Returns:
      a hparams
    """
    hparams = mtf_transformer.mtf_transformer_base_lm()

    # Settings applied before the mixture-of-experts defaults are installed.
    pre_moe_overrides = (
        ("shared_embedding_and_softmax_weights", False),
        # No dropout - the dataset is big enough to avoid overfitting.
        ("attention_dropout", 0.0),
        ("relu_dropout", 0.0),
        ("layer_prepostprocess_dropout", 0.0),
        ("max_length", 1024),
        # 4 sequences per core.
        ("batch_size", 32),
        # Linear decay is deliberately not used in these experiments, to avoid
        # a sharp jump in quality at the end of the training schedule. It can
        # be added back once the right architecture has been found.
        ("learning_rate_schedule", "rsqrt_decay"),
        ("mesh_shape", "all:8"),
        ("layout", "batch:all;experts:all"),
    )
    for name, value in pre_moe_overrides:
        setattr(hparams, name, value)

    # Install the default mixture-of-experts hparams first, then override the
    # ones this experiment sets explicitly.
    moe.set_default_moe_hparams(hparams)

    post_moe_overrides = (
        ("moe_num_experts", 16),
        ("moe_hidden_size", 8192),
        ("decoder_layers", ["att", "drd"] * 6),
        ("d_model", 1024),
        ("d_ff", 2048),
        ("d_kv", 128),
        ("num_heads", 4),
    )
    for name, value in post_moe_overrides:
        setattr(hparams, name, value)
    return hparams