Пример #1
0
def mtf_unitransformer_all_layers_tiny():
  """Tiny unitransformer hparams exercising every layer type on local CPU."""
  hparams = mtf_unitransformer_tiny()
  # One instance of each layer kind, all sized small for a CPU smoke test.
  layers = [
      transformer_layers.SelfAttention(num_heads=4),
      transformer_layers.LocalSelfAttention(num_heads=4),
      moe.MoE1D(num_experts=4, hidden_size=512),
      moe.MoE2D(expert_x=4, expert_y=4, hidden_size=512),
      transformer_layers.DenseReluDense(hidden_size=512),
  ]
  hparams.layer_stack = transformer.LayerStack(layers)
  return hparams
Пример #2
0
def mtf_transformer2_all_layers_tiny():
    """Tiny transformer2 hparams exercising every layer type on local CPU."""
    hparams = mtf_transformer2_base()
    # Shrink everything so the test fits on a single local CPU.
    hparams.batch_size = 2
    hparams.mesh_shape = ""
    hparams.d_model = 128
    # One instance of each layer kind, all sized small.
    all_layers = [
        transformer_layers.SelfAttention(num_heads=4),
        transformer_layers.LocalSelfAttention(num_heads=4),
        moe.MoE1D(num_experts=4, hidden_size=512),
        moe.MoE2D(expert_x=4, expert_y=4, hidden_size=512),
        transformer_layers.DenseReluDense(hidden_size=512),
    ]
    hparams.layer_stack = transformer.LayerStack(all_layers)
    return hparams
Пример #3
0
def moe_1d_layer(hparams, prefix):
    """Build a 1D mixture-of-experts layer from *hparams*.

    *prefix* is accepted only to keep a uniform layer-factory signature
    and is ignored.
    """
    del prefix  # unused; kept for signature compatibility
    return moe.MoE1D(
        num_experts=hparams.moe_num_experts,
        hidden_size=hparams.moe_hidden_size)
def moe_1d_layer(hparams, prefix):
    """Build a 1D mixture-of-experts layer from *hparams*.

    NOTE(review): this redefinition shadows the identical factory defined
    just above it. The original body passed ``model_d=hparams.moe_model_d``,
    but every other ``moe.MoE1D`` call in this file uses ``hidden_size``
    (and mesh-tensorflow's ``MoE1D`` constructor takes ``hidden_size``,
    not ``model_d``) — fixed to match.

    *prefix* is accepted only to keep a uniform layer-factory signature
    and is ignored.
    """
    del prefix  # unused; kept for signature compatibility
    return moe.MoE1D(num_experts=hparams.moe_num_experts,
                     hidden_size=hparams.moe_hidden_size)