def xmoe_wiki_base():
  """Series of architectural experiments on wikipedia text.

  For all of these architectures, we run on languagemodel_wiki_noref_v8k_l1k
  for 3 epochs.  (The training set has ~7390100 sequences, each of length
  1024; 1 epoch = 115000 steps at batch_size=64.)

  Results:
  model             params(M)  einsum  alltoall  mxu-util  log-ppl(1ep) (3ep)

  Note: configurations and code are likely to change without notice.

  Returns:
    a hparams object.
  """
  hparams = mtf_transformer.mtf_transformer_base()

  # The following hparams are constant across all these experiments.
  hparams.label_smoothing = 0.0
  hparams.max_length = 1024
  hparams.batch_size = 64
  hparams.d_model = 1024
  hparams.d_kv = 128
  hparams.num_heads = 8
  hparams.shared_embedding_and_softmax_weights = False
  hparams.learning_rate_decay_steps = 115000

  # We will vary the following parameters related to the ffn/moe layers.
  hparams.feedforward_layer = "dense_relu_dense"
  hparams.d_ff = 8192
  hparams.layout = "batch:batch;vocab:model;d_ff:model;heads:model"
  hparams.mesh_shape = "batch:32"
  return hparams
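

# Illustrative sketch (hypothetical helper, not used by the configs above): the
# decay horizon in xmoe_wiki_base follows from the dataset size quoted in its
# docstring -- ~7390100 training sequences at batch_size=64 is roughly 115k
# steps per epoch, which is the value used for learning_rate_decay_steps.
def _wiki_steps_per_epoch_sketch(num_sequences=7390100, batch_size=64):
  """Steps in one epoch of the wiki training set (illustration only)."""
  return num_sequences // batch_size  # ~115470; quoted as 115000 above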


def xmoe_dense_4k():
  """Series of architectural experiments on cheap language models.

  For all of these architectures, we run on languagemodel_lm1b8k_packed
  for 32k-96k steps (1-3 epochs) on one TPU (8 cores).

  All log-perplexities are per-token; multiply by 1.298 for per-word.

  Results:
  model             params(M)  einsum  alltoall  mxu-util  log-ppl(1ep) (3ep)
  xmoe_dense_4k     30         3.0e12  0         45%        3.31
  xmoe_dense_8k     46         4.7e12  0         49%        3.24
  xmoe_dense_64k    282        2.8e13  0                    3.06
  xmoe_top_2        282        4.0e12  3.4e8     36%        3.07
  xmoe_top_2_c15    282        4.5e12  4.0e8     38%        3.07
  xmoe_2d           282        5.3e12  7.6e8     34%        3.06

  Trained at 4x the batch size:
  xmoe_2d_88        1090       2.1e13  3.0e9     24%

  Note: configurations and code are likely to change without notice.

  Returns:
    a hparams object.
  """
  hparams = mtf_transformer.mtf_transformer_base()

  # The following hparams are constant across all these experiments.
  hparams.label_smoothing = 0.0
  hparams.batch_size = 128
  hparams.d_model = 512
  hparams.d_kv = 128
  hparams.num_heads = 4
  hparams.num_decoder_layers = 4
  hparams.shared_embedding_and_softmax_weights = False
  hparams.learning_rate_schedule = "rsqrt_decay"

  # We will vary the following parameters related to the ffn/moe layers.
  hparams.feedforward_layer = "dense_relu_dense"
  hparams.d_ff = 4096
  hparams.layout = "batch:batch;vocab:model;d_ff:model;heads:model"
  hparams.mesh_shape = "batch:8"
  return hparams
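

# Illustrative sketch (hypothetical helper, not part of the experiment code):
# the results table above reports per-token log-perplexities; per-word values
# follow by multiplying by the ~1.298 tokens-per-word ratio mentioned in the
# docstring, e.g. 3.31 per-token -> ~4.30 per-word for xmoe_dense_4k.
def _per_word_log_ppl_sketch(per_token_log_ppl, tokens_per_word=1.298):
  """Convert a per-token log-perplexity to per-word (illustration only)."""
  return per_token_log_ppl * tokens_per_word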


def xmoe_dense_4k_moe_defaults():
  """Small transformer language model with default MoE hparams.

  Same configuration as xmoe_dense_4k, plus default values for the
  mixture-of-experts hparams varied in these experiments.
  """
  hparams = mtf_transformer.mtf_transformer_base()

  # The following hparams are constant across all these experiments.
  hparams.label_smoothing = 0.0
  hparams.batch_size = 128
  hparams.d_model = 512
  hparams.d_kv = 128
  hparams.num_heads = 4
  hparams.num_decoder_layers = 4
  hparams.shared_embedding_and_softmax_weights = False
  hparams.learning_rate_schedule = "rsqrt_decay"

  # We will vary the following parameters related to the ffn/moe layers.
  hparams.feedforward_layer = "dense_relu_dense"
  hparams.d_ff = 4096
  hparams.moe_num_experts = 16
  hparams.moe_overhead_train = 1.0
  hparams.moe_overhead_eval = 2.0
  hparams.moe_loss_coef = 1e-3
  hparams.layout = "batch:batch;vocab:model;d_ff:model;heads:model"
  hparams.mesh_shape = "batch:8"
  return hparams
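

# Sketch of how the wider dense variants listed in the results table (e.g.
# xmoe_dense_8k with d_ff=8192) would be derived from the base config by
# overriding a single hparam.  The exact original definitions are not shown in
# this excerpt, so the function below is an illustrative assumption.
def xmoe_dense_8k_sketch():
  """Illustrative d_ff=8192 variant of xmoe_dense_4k (assumed definition)."""
  hparams = xmoe_dense_4k()
  hparams.d_ff = 8192
  return hparams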