Example #1
def mtr_lm_v1(num_heads=8, num_memory_heads=0):
    """Model incorporating mixture-of-experts, local and global attention.

  ~6B parameters

  32 experts in 3 hierarchichal moe layers.

  Args:
    num_heads: an optional integer
    num_memory_heads: an optional integer

  Returns:
    a hparams
  """
    hparams = mtr_lm_dense(0)
    local_att = transformer_layers.LocalSelfAttention(
        num_heads=num_heads,
        num_memory_heads=num_memory_heads,
        key_value_size=128)
    att = transformer_layers.SelfAttention(num_heads=num_heads,
                                           num_memory_heads=num_memory_heads,
                                           key_value_size=128)
    drd = transformer_layers.DenseReluDense(hidden_size=2048)
    hmoe = moe.MoE2D(expert_x=8, expert_y=4, hidden_size=32768)
    hparams.layer_stack = transformer.LayerStack(
        ([local_att, local_att, drd, att, drd, local_att, local_att, hmoe] *
         4)[:-1])
    hparams.mesh_shape = "b0:4;b1:8"
    hparams.layout = "outer_batch:b0;inner_batch:b1,expert_x:b1,expert_y:b0"
    hparams.outer_batch_size = 4
    return hparams
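In the stack above, the 8-layer block is repeated four times and the trailing hmoe is sliced off, which leaves 31 layers and exactly the 3 hierarchical MoE layers (8 × 4 = 32 experts each) mentioned in the docstring. A minimal, library-free sketch of the (block * 4)[:-1] idiom, using string stand-ins for the layer objects:

# String stand-ins for the layer objects; only the slicing matters here.
block = ["local_att", "local_att", "drd", "att",
         "drd", "local_att", "local_att", "hmoe"]
layers = (block * 4)[:-1]        # repeat 4x, then drop the trailing hmoe
print(len(layers))               # 31
print(layers.count("hmoe"))      # 3 hierarchical MoE layers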
Example #2
def local_self_attention_layer(hparams, prefix):
    """Create self-attention layer based on hyperparameters."""
    return transformer_layers.LocalSelfAttention(
        num_heads=hparams.get(prefix + "num_heads"),
        num_memory_heads=hparams.get(prefix + "num_memory_heads"),
        radius=hparams.local_attention_radius,
        key_value_size=hparams.d_kv,
        shared_kv=hparams.get(prefix + "shared_kv", False),
        attention_kwargs=attention_kwargs_from_hparams(hparams))
Example #3
def mtf_unitransformer_all_layers_tiny():
  """Test out all the layers on local CPU."""
  hparams = mtf_unitransformer_tiny()
  hparams.layer_stack = transformer.LayerStack(
      [transformer_layers.SelfAttention(num_heads=4),
       transformer_layers.LocalSelfAttention(num_heads=4),
       moe.MoE1D(num_experts=4, hidden_size=512),
       moe.MoE2D(expert_x=4, expert_y=4, hidden_size=512),
       transformer_layers.DenseReluDense(hidden_size=512)])
  return hparams
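The tiny stack above strings together one instance of each layer type. As a sketch of the same idea in a more declarative form, the helper below builds the stack from a list of short names; the imports assume the mesh_tensorflow.transformer package layout, and build_tiny_stack itself is our own illustration, not part of the library:

from mesh_tensorflow.transformer import moe, transformer, transformer_layers

# Map short names to layer constructors, using the same arguments as Example #3.
LAYER_BUILDERS = {
    "att": lambda: transformer_layers.SelfAttention(num_heads=4),
    "local_att": lambda: transformer_layers.LocalSelfAttention(num_heads=4),
    "moe_1d": lambda: moe.MoE1D(num_experts=4, hidden_size=512),
    "moe_2d": lambda: moe.MoE2D(expert_x=4, expert_y=4, hidden_size=512),
    "drd": lambda: transformer_layers.DenseReluDense(hidden_size=512),
}

def build_tiny_stack(names=("att", "local_att", "moe_1d", "moe_2d", "drd")):
  """Build a LayerStack containing one layer per name, in order."""
  return transformer.LayerStack([LAYER_BUILDERS[name]() for name in names])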
Example #4
def mtf_transformer2_all_layers_tiny():
    """Test out all the layers on local CPU."""
    hparams = mtf_transformer2_base()
    hparams.batch_size = 2
    hparams.mesh_shape = ""
    hparams.d_model = 128
    hparams.layer_stack = transformer.LayerStack([
        transformer_layers.SelfAttention(num_heads=4),
        transformer_layers.LocalSelfAttention(num_heads=4),
        moe.MoE1D(num_experts=4, hidden_size=512),
        moe.MoE2D(expert_x=4, expert_y=4, hidden_size=512),
        transformer_layers.DenseReluDense(hidden_size=512)
    ])
    return hparams
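Example #4 sets mesh_shape = "" so the test stack runs without any mesh splitting on a local CPU, whereas the sharded model in Example #1 uses "b0:4;b1:8" (two mesh dimensions of sizes 4 and 8). A minimal, library-free sketch of how such a mesh-shape string decomposes; parse_mesh_shape is a hypothetical helper, not the library's own parser:

def parse_mesh_shape(mesh_shape):
    """Split a mesh-shape string like "b0:4;b1:8" into (name, size) pairs."""
    if not mesh_shape:
        return []  # empty string -> no mesh dimensions, run on a single device
    dims = []
    for part in mesh_shape.split(";"):
        name, size = part.split(":")
        dims.append((name, int(size)))
    return dims

print(parse_mesh_shape("b0:4;b1:8"))  # [('b0', 4), ('b1', 8)]
print(parse_mesh_shape(""))           # []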