def mtf_unitransformer_all_layers_tiny():
  """Tiny unitransformer hparams exercising every layer type on local CPU."""
  hparams = mtf_unitransformer_tiny()
  # One instance of each layer kind, kept small so the stack runs on CPU.
  all_layers = [
      transformer_layers.SelfAttention(num_heads=4),
      transformer_layers.LocalSelfAttention(num_heads=4),
      moe.MoE1D(num_experts=4, hidden_size=512),
      moe.MoE2D(expert_x=4, expert_y=4, hidden_size=512),
      transformer_layers.DenseReluDense(hidden_size=512),
  ]
  hparams.layer_stack = transformer.LayerStack(all_layers)
  return hparams
def mtf_transformer2_all_layers_tiny():
  """Tiny transformer2 hparams exercising every layer type on local CPU."""
  hparams = mtf_transformer2_base()
  # Shrink the model and disable the mesh so it fits on a single CPU.
  hparams.batch_size = 2
  hparams.mesh_shape = ""
  hparams.d_model = 128
  # One instance of each layer kind, kept small for the local test.
  all_layers = [
      transformer_layers.SelfAttention(num_heads=4),
      transformer_layers.LocalSelfAttention(num_heads=4),
      moe.MoE1D(num_experts=4, hidden_size=512),
      moe.MoE2D(expert_x=4, expert_y=4, hidden_size=512),
      transformer_layers.DenseReluDense(hidden_size=512),
  ]
  hparams.layer_stack = transformer.LayerStack(all_layers)
  return hparams
def moe_1d_layer(hparams, prefix):
  """Build a 1-d mixture-of-experts layer configured from `hparams`.

  Args:
    hparams: an hparams object providing `moe_num_experts` and
      `moe_hidden_size`.
    prefix: unused, accepted for layer-factory signature compatibility.

  Returns:
    a `moe.MoE1D` layer instance.
  """
  del prefix  # unused
  return moe.MoE1D(
      num_experts=hparams.moe_num_experts,
      hidden_size=hparams.moe_hidden_size)
def moe_1d_layer(hparams, prefix):
  """Build a 1-d mixture-of-experts layer configured from `hparams`.

  NOTE(review): this redefines `moe_1d_layer` declared earlier in the file
  and shadows it; it also passes `model_d=` where the earlier version passes
  `hidden_size=`. Presumably one of the two is stale — confirm which
  definition callers expect and delete the other.

  Args:
    hparams: an hparams object providing `moe_num_experts` and
      `moe_model_d`.
    prefix: unused, accepted for layer-factory signature compatibility.

  Returns:
    a `moe.MoE1D` layer instance.
  """
  del prefix  # unused
  return moe.MoE1D(
      num_experts=hparams.moe_num_experts,
      model_d=hparams.moe_model_d)