def moe(self, x, layout, mesh_shape, input_mask, is_training):
  """Mixture of experts layer.

  TODO(noam): clean up the mixture-of-experts code in Transformer.

  Args:
    x: layer input
    layout: a mtf.LayoutRules
    mesh_shape: a mtf.Shape
    input_mask: a mtf.Tensor
    is_training: a boolean

  Returns:
    a mtf.Tensor (the layer output)
  """
  hparams = moe.HParams(
      moe_gating="top_2",
      moe_num_experts=self.config.moe_num_experts,
      moe_loss_coef=1e-3,
      moe_hidden_size=self.config.moe_intermediate_size,
      moe_group_size=2048,
      moe_capacity_factor_train=1.25,
      moe_capacity_factor_eval=8.0,
      moe_use_second_place_loss=False,
      moe_second_policy_train="random",
      moe_second_policy_eval="random",
      moe_second_threshold_train=0.2,
      moe_second_threshold_eval=0.2,
      moe_dropout_rate=0.0,
      moe_use_experts_attention=False,
      moe_min_expert_capacity=4)
  layer_output, loss = moe.transformer_moe_layer_v1(
      inputs=x,
      output_dim=self.model_dim,
      hparams=hparams,
      train=is_training,
      variable_dtype=tf.float32,
      layout=layout,
      mesh_shape=mesh_shape,
      nonpadding=(mtf.cast(input_mask, tf.float32)
                  if input_mask is not None else None),
      activation=get_activation(
          self.config.feedforward_intermediate_act))
  # Collect the load-balancing loss so the training loop can add it to
  # the primary objective.
  self._extra_losses.append(loss)
  return layer_output
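# For a sense of how the capacity hparams above interact, here is a small,
# self-contained sketch of the expert-capacity computation used by top-2
# gating: tokens are routed in groups of moe_group_size, each expert gets a
# number of slots per group scaled by the capacity factor and floored at
# moe_min_expert_capacity. This mirrors mesh_tensorflow's moe.py as best
# understood; treat it as an illustrative sketch, not the exact library code.
def _expert_capacity_sketch(group_size, num_experts, capacity_factor,
                            min_expert_capacity=4):
  # Each expert receives a fixed budget of token slots per group; tokens
  # routed past that budget overflow (dropped, or sent to a second choice).
  capacity = int(group_size * capacity_factor / num_experts)
  capacity = max(capacity, min_expert_capacity)
  return min(capacity, group_size)

# With the hparams above (group_size=2048) and, say, 64 experts:
#   train (capacity_factor=1.25) -> 40 slots per expert per group
#   eval  (capacity_factor=8.0)  -> 256 slots per expert per group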
def __init__(self,
             num_experts=16,
             loss_coef=1e-2,
             hidden_size=4096,
             group_size=1024,
             capacity_factor_train=1.25,
             capacity_factor_eval=2.0,
             use_second_place_loss=False,
             second_policy_train="random",
             second_policy_eval="random",
             second_threshold_train=0.2,
             second_threshold_eval=0.2,
             dropout_rate=0.0,
             activation="relu",
             moe_gating="top_2",
             min_expert_capacity=4,
             switch_policy_train="input_jitter",
             switch_policy_eval="input_jitter",
             switch_dropout=0.1,
             switch_temperature=1.0,
             switch_jitter=1e-2,
             ntlb_top_k=4,
             output_dim=None,
             z_loss=None,
             word_embed_mode=None,
             use_second_place_expert_prob=None,
             use_second_place_expert_prob_temp=None,
             moe_num_layers=1,
             heterogeneous_mask_info=None,
             top_n_num_experts_per_token=3):
  """Creates a mixture-of-experts layer.

  Each argument is forwarded unchanged (under a `moe_`-prefixed name) to
  moe.HParams; `activation` is stored separately on the layer.
  """
  self._hparams = moe.HParams(
      moe_gating=moe_gating,
      moe_num_experts=num_experts,
      moe_loss_coef=loss_coef,
      moe_hidden_size=hidden_size,
      moe_group_size=group_size,
      moe_min_expert_capacity=min_expert_capacity,
      moe_capacity_factor_train=capacity_factor_train,
      moe_capacity_factor_eval=capacity_factor_eval,
      moe_use_second_place_loss=use_second_place_loss,
      moe_second_policy_train=second_policy_train,
      moe_second_policy_eval=second_policy_eval,
      moe_second_threshold_train=second_threshold_train,
      moe_second_threshold_eval=second_threshold_eval,
      moe_dropout_rate=dropout_rate,
      moe_switch_policy_train=switch_policy_train,
      moe_switch_policy_eval=switch_policy_eval,
      moe_switch_dropout=switch_dropout,
      moe_switch_temperature=switch_temperature,
      moe_switch_jitter=switch_jitter,
      moe_output_dim=output_dim,
      moe_ntlb_top_k=ntlb_top_k,
      moe_z_loss=z_loss,
      moe_word_embed_mode=word_embed_mode,
      moe_use_second_place_expert_prob=use_second_place_expert_prob,
      moe_use_second_place_expert_prob_temp=(
          use_second_place_expert_prob_temp),
      moe_top_n_num_experts_per_token=top_n_num_experts_per_token,
      moe_num_layers=moe_num_layers,
      moe_heterogeneous_mask_info=heterogeneous_mask_info)
  self._activation = activation
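# A hypothetical construction showing how the keyword arguments map onto
# moe.HParams fields. The enclosing class name is not shown in this file,
# so `MoE1D` below is an assumption (the signature resembles the
# mesh_tensorflow layer of that name); the argument values are illustrative.
moe_layer = MoE1D(
    num_experts=32,
    hidden_size=8192,
    moe_gating="switch",                  # single-expert (switch) routing
    switch_policy_train="input_jitter",   # jitter inputs to the router
    switch_jitter=1e-2,
    capacity_factor_train=1.25)
# Switching moe_gating (e.g. "top_2" vs. "switch") selects which routing
# branch the layer uses; the switch_* arguments only take effect for
# switch-style gating, and the second_* arguments only for top-2 gating.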