Code example #1
def xmoe2_v1():
  """Model incorporating mixture-of-experts and local-attention.

  ~6B parameters

  32 experts in 3 hierarchical moe layers.

  Returns:
    a hparams
  """
  hparams = xmoe2_dense(0)
  moe.set_default_moe_hparams(hparams)
  hparams.decoder_layers = (
      ["local_att", "local_att", "drd",
       "att", "drd", "local_att", "local_att", "hmoe"] * 4)[:-1]
  hparams.d_ff = 2048
  hparams.d_kv = 128
  hparams.moe_hidden_size = 32768
  hparams.mesh_shape = "b0:4;b1:8"
  hparams.layout = "outer_batch:b0;inner_batch:b1,expert_x:b1,expert_y:b0"
  hparams.outer_batch_size = 4
  hparams.moe_num_experts = [8, 4]
  hparams.num_heads = 4
  return hparams
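
A quick standalone check of the layer pattern above (plain Python, not from the tensor2tensor sources) confirms the docstring: the repeated block yields 31 decoder layers, 3 of which are hierarchical MoE layers, and moe_num_experts = [8, 4] means 8 * 4 = 32 experts per MoE layer.

# Standalone sketch: verify "32 experts in 3 hierarchical moe layers".
pattern = ["local_att", "local_att", "drd",
           "att", "drd", "local_att", "local_att", "hmoe"]
decoder_layers = (pattern * 4)[:-1]          # the trailing "hmoe" is dropped
print(len(decoder_layers))                   # 31 decoder layers
print(decoder_layers.count("hmoe"))          # 3 hierarchical MoE layers
print(8 * 4)                                 # 32 experts per MoE layer ([8, 4])
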
Code example #2
def xmoe_top_2():
  """Mixture of experts (16 experts)."""
  hparams = xmoe_dense_4k()
  moe.set_default_moe_hparams(hparams)
  hparams.mesh_shape = "all:8"
  hparams.layout = "batch:all;experts:all"
  return hparams
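
The mesh_shape and layout strings in these configs follow the Mesh TensorFlow "name:size" and "tensor_dim:mesh_dim" conventions: "all:8" declares a single mesh dimension of 8 cores, and "batch:all;experts:all" says that a tensor's batch dimension and its experts dimension are each split across that mesh dimension. A minimal parsing sketch (plain Python, not the mesh_tensorflow API):

# Sketch only: parse the mesh/layout strings used in these configs.
def parse_mesh_shape(mesh_shape):
  # "all:8" -> [("all", 8)]; "b0:4;b1:8" -> [("b0", 4), ("b1", 8)]
  return [(name, int(size))
          for name, size in (entry.split(":") for entry in mesh_shape.split(";"))]

def parse_layout(layout):
  # "batch:all;experts:all" -> {"batch": "all", "experts": "all"}
  # Both ";" and "," appear as separators in the configs above.
  entries = layout.replace(",", ";").split(";")
  return dict(entry.split(":") for entry in entries)

print(parse_mesh_shape("b0:4;b1:8"))
print(parse_layout("outer_batch:b0;inner_batch:b1,expert_x:b1,expert_y:b0"))
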
Code example #3
File: moe_experiments.py  Project: zhyq/tensor2tensor
def xmoe_wiki_x32():
  """Two-dimensional hierarchical mixture of experts.

  (8x4 experts) * (16M params/expert) * 6 layers = 3B params

  Returns:
    a hparams object.
  """
  hparams = xmoe_wiki_base()
  moe.set_default_moe_hparams(hparams)
  hparams.feedforward_layer = "hmoe"
  hparams.moe_hidden_size = 8192
  hparams.mesh_shape = "b0:4;b1:8"
  hparams.layout = "outer_batch:b0;inner_batch:b1,expert_x:b1,expert_y:b0"
  hparams.outer_batch_size = 4
  hparams.moe_num_experts = [8, 4]
  return hparams
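
The docstring arithmetic can be unpacked in a short sketch. Assuming d_model = 1024 (the value used by wiki_2x2_base further down; it is not set in this snippet), a two-matrix feed-forward expert has about 2 * 1024 * 8192 ≈ 16.8M parameters, which is the "16M params/expert" figure; 8 * 4 = 32 experts times 6 MoE layers then gives roughly the stated 3B parameters.

# Sketch of the docstring arithmetic (assumes d_model = 1024, as in
# wiki_2x2_base below; an expert is an input plus an output projection).
d_model = 1024
moe_hidden_size = 8192
params_per_expert = 2 * d_model * moe_hidden_size      # ~16.8M, the "16M" figure
experts_per_layer = 8 * 4                               # moe_num_experts = [8, 4]
num_moe_layers = 6
print(experts_per_layer * params_per_expert * num_moe_layers / 1e9)  # ~3.2, the "3B params"
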
Code example #4
File: moe_experiments.py  Project: zhyq/tensor2tensor
def mtf_transformer_lm_moe():
  """Mixture of experts language model.

  Compare to mtf_transformer.mtf_transformer_lm_baseline()

  Run this on 2x2 on languagemodel_lm1b32k_packed for 272000 steps (10 epochs).
  900M params.

  Results on LM1B:
         params/10^9  log-ppl(per-token)
         0.90         TODO(noam): rerun experiment

  Returns:
    a hparams
  """
  hparams = mtf_transformer.mtf_transformer_lm_baseline()
  moe.set_default_moe_hparams(hparams)
  hparams.mesh_shape = "all:8"
  hparams.layout = "batch:all;experts:all"
  hparams.feedforward_layer = "moe"
  return hparams
Code example #5
def xmoe_wiki_x():
  """Baseline set of parameters for mixture-of-experts.

  ~6B parameters

  Returns:
    a hparams
  """
  hparams = xmoe_wiki_base(0)
  moe.set_default_moe_hparams(hparams)
  hparams.decoder_layers = (["att", "drd", "att", "drd", "att", "hmoe"] * 3 +
                            ["att", "drd", "att", "drd"])
  hparams.d_ff = 2048
  hparams.d_kv = 128
  hparams.moe_hidden_size = 32768
  hparams.mesh_shape = "b0:4;b1:8"
  hparams.layout = "outer_batch:b0;inner_batch:b1,expert_x:b1,expert_y:b0"
  hparams.outer_batch_size = 4
  hparams.moe_num_experts = [8, 4]
  hparams.num_heads = 4
  return hparams
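
This pattern interleaves plain attention and MoE only (no local attention, unlike xmoe2_v1 above), and it can be counted the same way (standalone Python, not from the tensor2tensor sources):

# Count the layers built by the xmoe_wiki_x decoder pattern.
layers = ["att", "drd", "att", "drd", "att", "hmoe"] * 3 + ["att", "drd", "att", "drd"]
print(len(layers))              # 22 decoder layers
print(layers.count("hmoe"))     # 3 hierarchical MoE layers, 8 * 4 = 32 experts each
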
Code example #6
def wiki_2x2_base():
  """Set of architectural experiments - language model on wikipedia on a 2x2.

  1 epoch = ~180k steps at batch size 32 - we may never finish an epoch!

  Returns:
    a hparams
  """
  hparams = mtf_transformer.mtf_transformer_base_lm()
  hparams.shared_embedding_and_softmax_weights = False
  # no dropout - dataset is big enough to avoid overfitting.
  hparams.attention_dropout = 0.0
  hparams.relu_dropout = 0.0
  hparams.layer_prepostprocess_dropout = 0.0
  hparams.max_length = 1024
  # 4 sequences per core
  hparams.batch_size = 32
  # We don't use linear decay in these experiments, since we don't want
  # a sharp jump in quality at the end of the training schedule.
  # You can insert this once you find the right architecture.
  hparams.learning_rate_schedule = "rsqrt_decay"
  hparams.mesh_shape = "all:8"
  hparams.layout = "batch:all;experts:all"

  # parameters for mixture-of-experts
  moe.set_default_moe_hparams(hparams)
  hparams.moe_num_experts = 16
  hparams.moe_hidden_size = 8192

  hparams.decoder_layers = ["att", "drd"] * 6
  hparams.d_model = 1024
  hparams.d_ff = 2048
  hparams.d_kv = 128
  hparams.num_heads = 4

  return hparams
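
The comments in wiki_2x2_base can be cross-checked: a 2x2 TPU slice exposes 8 cores ("all:8"), so batch_size = 32 is 4 sequences per core, and at max_length = 1024 each step covers 32 * 1024 tokens; ~180k steps per epoch then implies a corpus of roughly 6B tokens, assuming sequences are packed to full length. A small sketch of that arithmetic:

# Arithmetic behind the wiki_2x2_base comments (2x2 TPU slice = 8 cores).
cores = 8                              # mesh_shape = "all:8"
batch_size = 32
max_length = 1024
steps_per_epoch = 180_000              # "~180k steps" from the docstring

print(batch_size // cores)                              # 4 sequences per core
print(batch_size * max_length)                          # 32768 tokens per step
print(batch_size * max_length * steps_per_epoch / 1e9)  # ~5.9B tokens per epoch (if fully packed)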