def mtr_lm_v1(num_heads=8, num_memory_heads=0):
  """Model incorporating mixture-of-experts, local and global attention.

  ~6B parameters

  32 experts in 3 hierarchical moe layers.

  Args:
    num_heads: an optional integer
    num_memory_heads: an optional integer

  Returns:
    a hparams
  """
  hparams = mtr_lm_dense(0)
  local_att = transformer_layers.LocalSelfAttention(
      num_heads=num_heads, num_memory_heads=num_memory_heads,
      key_value_size=128)
  att = transformer_layers.SelfAttention(
      num_heads=num_heads, num_memory_heads=num_memory_heads,
      key_value_size=128)
  drd = transformer_layers.DenseReluDense(hidden_size=2048)
  hmoe = moe.MoE2D(expert_x=8, expert_y=4, hidden_size=32768)
  hparams.layer_stack = transformer.LayerStack(
      ([local_att, local_att, drd,
        att, drd, local_att, local_att, hmoe] * 4)[:-1])
  hparams.mesh_shape = "b0:4;b1:8"
  hparams.layout = "outer_batch:b0;inner_batch:b1,expert_x:b1,expert_y:b0"
  hparams.outer_batch_size = 4
  return hparams
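# Sanity-check sketch (added for illustration; this helper is not part of the
# original file): verifies the counts stated in the mtr_lm_v1 docstring. The
# 8-layer block repeated 4 times, with the trailing hmoe dropped by [:-1],
# yields a 31-layer stack containing 3 MoE2D layers, and each MoE2D holds
# expert_x * expert_y = 8 * 4 = 32 experts.
def _check_mtr_lm_v1_layer_counts():
  block = ["local_att", "local_att", "drd",
           "att", "drd", "local_att", "local_att", "hmoe"]
  stack = (block * 4)[:-1]
  assert len(stack) == 31
  assert stack.count("hmoe") == 3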
def mtf_unitransformer_all_layers_tiny():
  """Test out all the layers on local CPU."""
  hparams = mtf_unitransformer_tiny()
  hparams.layer_stack = transformer.LayerStack(
      [transformer_layers.SelfAttention(num_heads=4),
       transformer_layers.LocalSelfAttention(num_heads=4),
       moe.MoE1D(num_experts=4, hidden_size=512),
       moe.MoE2D(expert_x=4, expert_y=4, hidden_size=512),
       transformer_layers.DenseReluDense(hidden_size=512)])
  return hparams
def mtf_transformer2_all_layers_tiny():
  """Test out all the layers on local CPU."""
  hparams = mtf_transformer2_base()
  hparams.batch_size = 2
  hparams.mesh_shape = ""
  hparams.d_model = 128
  hparams.layer_stack = transformer.LayerStack([
      transformer_layers.SelfAttention(num_heads=4),
      transformer_layers.LocalSelfAttention(num_heads=4),
      moe.MoE1D(num_experts=4, hidden_size=512),
      moe.MoE2D(expert_x=4, expert_y=4, hidden_size=512),
      transformer_layers.DenseReluDense(hidden_size=512)
  ])
  return hparams
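# Usage sketch (illustrative; this helper is an assumption, not part of the
# original file, and it assumes LayerStack exposes its layers via the `layers`
# property as in mesh_tensorflow.transformer): a quick way to confirm the tiny
# configs above exercise every layer type before a local CPU run.
def _print_tiny_layer_stack():
  hparams = mtf_transformer2_all_layers_tiny()
  for layer in hparams.layer_stack.layers:
    print(type(layer).__name__)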
def moe_2d_layer(hparams, prefix):
  """Constructs a 2D mixture-of-experts layer from hparams."""
  del prefix
  return moe.MoE2D(expert_x=hparams.moe_expert_x,
                   expert_y=hparams.moe_expert_y,
                   hidden_size=hparams.moe_hidden_size)
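# Dispatch sketch (illustrative only; the registry dict and the function below
# are hypothetical, not the source API): constructor helpers like moe_2d_layer
# take (hparams, prefix) so they can be looked up by name when assembling a
# LayerStack from an hparams-specified list of layer names.
_HYPOTHETICAL_LAYER_CONSTRUCTORS = {
    "moe_2d": moe_2d_layer,
}

def _build_layer_from_name(name, hparams, prefix=""):
  return _HYPOTHETICAL_LAYER_CONSTRUCTORS[name](hparams, prefix)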