Esempio n. 1
0
def xmoe_dense_4k():
    """Series of architectural experiments on cheap language models.

  For all of these architectures, we run on languagemodel_lm1b8k_packed
  for 32000 steps.

  All log-perplexities are per-token - multiply by 1.298 for per-word

  Results:
  model             params(M)  einsum  alltoall  mxu-util  log-ppl
  xmoe_dense_4k     30         3.0e12  0         45%        3.31
  xmoe_dense_8k     46         4.7e12  0         49%        3.24
  xmoe_dense_64k    282        2.8e13  0                    3.06
  xmoe_top_2        282        4.0e12  3.4e8     36%        3.07
  xmoe_top_2_c15    282        4.5e12  4.0e8     38%        3.07
  xmoe_2d           282        5.3e12  7.6e8     34%        3.06

  Trained at 4x the batch size:
  xmoe_2d_88        1090       2.1e13  3.0e9     24%        3.07

  Note: configurations and code are likely to change without notice.

  Returns:
    a hparams
  """
    hparams = mtf_transformer.mtf_transformer_base_lm()
    hparams.attention_dropout = 0.0
    hparams.relu_dropout = 0.0
    hparams.layer_prepostprocess_dropout = 0.0

    # The following hparams are constant across all these experiments.
    hparams.batch_size = 128
    hparams.d_model = 512
    hparams.d_kv = 128
    hparams.num_heads = 4
    hparams.decoder_layers = ["att", "drd"] * 4
    hparams.shared_embedding_and_softmax_weights = False
    hparams.learning_rate_schedule = "rsqrt_decay"

    # We will vary the following parameters related to the ffn/moe layers.
    hparams.d_ff = 4096
    hparams.layout = "batch:batch;vocab:model;d_ff:model;heads:model"
    hparams.mesh_shape = "batch:8"
    return hparams
Esempio n. 2
0
def xmoe_dense_4k():
  """Series of architectural experiments on cheap language models.

  For all of these architectures, we run on languagemodel_lm1b8k_packed
  for 32000 steps.

  All log-perplexities are per-token - multiply by 1.298 for per-word

  Results:
  model             params(M)  einsum  alltoall  mxu-util  log-ppl
  xmoe_dense_4k     30         3.0e12  0         45%        3.31
  xmoe_dense_8k     46         4.7e12  0         49%        3.24
  xmoe_dense_64k    282        2.8e13  0                    3.06
  xmoe_top_2        282        4.0e12  3.4e8     36%        3.07
  xmoe_top_2_c15    282        4.5e12  4.0e8     38%        3.07
  xmoe_2d           282        5.3e12  7.6e8     34%        3.06

  Trained at 4x the batch size:
  xmoe_2d_88        1090       2.1e13  3.0e9     24%        3.07

  Note: configurations and code are likely to change without notice.

  Returns:
    a hparams
  """
  hparams = mtf_transformer.mtf_transformer_base_lm()
  hparams.attention_dropout = 0.0
  hparams.relu_dropout = 0.0
  hparams.layer_prepostprocess_dropout = 0.0

  # The following hparams are constant across all these experiments.
  hparams.batch_size = 128
  hparams.d_model = 512
  hparams.d_kv = 128
  hparams.num_heads = 4
  hparams.decoder_layers = ["att", "drd"] * 4
  hparams.shared_embedding_and_softmax_weights = False
  hparams.learning_rate_schedule = "rsqrt_decay"

  # We will vary the following parameters related to the ffn/moe layers.
  hparams.d_ff = 4096
  hparams.layout = "batch:batch;vocab:model;d_ff:model;heads:model"
  hparams.mesh_shape = "batch:8"
  return hparams
Esempio n. 3
0
def wiki_2x2_base():
    """Set of architectural experiments - language model on wikipedia on a 2x2.

  1 epoch = ~180k steps at batch size 32 - we may never finish an epoch!

  Returns:
    a hparams
  """
    hparams = mtf_transformer.mtf_transformer_base_lm()
    hparams.shared_embedding_and_softmax_weights = False
    # no dropout - dataset is big enough to avoid overfitting.
    hparams.attention_dropout = 0.0
    hparams.relu_dropout = 0.0
    hparams.layer_prepostprocess_dropout = 0.0
    hparams.max_length = 1024
    # 4 sequences per core
    hparams.batch_size = 32
    # We don't use linear decay in these experiments, since we don't want
    # a sharp jump in quality at the end of the training schedule.
    # You can insert this once you find the right architecture.
    hparams.learning_rate_schedule = "rsqrt_decay"
    hparams.mesh_shape = "all:8"
    hparams.layout = "batch:all;experts:all"

    # parameters for mixture-of-experts
    moe.set_default_moe_hparams(hparams)
    hparams.moe_num_experts = 16
    hparams.moe_hidden_size = 8192

    hparams.decoder_layers = ["att", "drd"] * 6
    hparams.d_model = 1024
    hparams.d_ff = 2048
    hparams.d_kv = 128
    hparams.num_heads = 4

    return hparams
Esempio n. 4
0
def wiki_2x2_base():
  """Set of architectural experiments - language model on wikipedia on a 2x2.

  1 epoch = ~180k steps at batch size 32 - we may never finish an epoch!

  Returns:
    a hparams
  """
  hparams = mtf_transformer.mtf_transformer_base_lm()
  hparams.shared_embedding_and_softmax_weights = False
  # no dropout - dataset is big enough to avoid overfitting.
  hparams.attention_dropout = 0.0
  hparams.relu_dropout = 0.0
  hparams.layer_prepostprocess_dropout = 0.0
  hparams.max_length = 1024
  # 4 sequences per core
  hparams.batch_size = 32
  # We don't use linear decay in these experiments, since we don't want
  # a sharp jump in quality at the end of the training schedule.
  # You can insert this once you find the right architecture.
  hparams.learning_rate_schedule = "rsqrt_decay"
  hparams.mesh_shape = "all:8"
  hparams.layout = "batch:all;experts:all"

  # parameters for mixture-of-experts
  moe.set_default_moe_hparams(hparams)
  hparams.moe_num_experts = 16
  hparams.moe_hidden_size = 8192

  hparams.decoder_layers = ["att", "drd"] * 6
  hparams.d_model = 1024
  hparams.d_ff = 2048
  hparams.d_kv = 128
  hparams.num_heads = 4

  return hparams