Code Example #1
File: bert.py  Project: min-xu-ai/mesh
    def moe(self, x, layout, mesh_shape, input_mask, is_training):
        """Mixture of experts layer.

        TODO(noam): clean up the mixture-of-experts code in Transformer.

        Args:
          x: layer input
          layout: a mtf.LayoutRules
          mesh_shape: a mtf.Shape
          input_mask: a mtf.Tensor
          is_training: a boolean
        Returns:
          a mtf.Tensor (the layer output)
        """
        # Gating and capacity hyperparameters for the expert layer.
        hparams = moe.HParams(
            moe_gating="top_2",
            moe_num_experts=self.config.moe_num_experts,
            moe_loss_coef=1e-3,
            moe_hidden_size=self.config.moe_intermediate_size,
            moe_group_size=2048,
            moe_capacity_factor_train=1.25,
            moe_capacity_factor_eval=8.0,
            moe_use_second_place_loss=False,
            moe_second_policy_train="random",
            moe_second_policy_eval="random",
            moe_second_threshold_train=0.2,
            moe_second_threshold_eval=0.2,
            moe_dropout_rate=0.0,
            moe_use_experts_attention=False,
            moe_min_expert_capacity=4)
        # The MoE layer returns the transformed activations together with an
        # auxiliary load-balancing loss.
        layer_output, loss = moe.transformer_moe_layer_v1(
            inputs=x,
            output_dim=self.model_dim,
            hparams=hparams,
            train=is_training,
            variable_dtype=tf.float32,
            layout=layout,
            mesh_shape=mesh_shape,
            nonpadding=(mtf.cast(input_mask, tf.float32)
                        if input_mask else None),
            activation=get_activation(
                self.config.feedforward_intermediate_act))
        # Accumulate the auxiliary loss so it can be folded into the training objective.
        self._extra_losses.append(loss)
        return layer_output
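
Both snippets are built on Mesh TensorFlow's moe module (the surrounding files would typically have import mesh_tensorflow as mtf, import tensorflow as tf, and from mesh_tensorflow.transformer import moe). The central setting is moe_gating="top_2": each token is routed to its highest-scoring expert and, subject to the second-place policy and threshold, to a runner-up, while every expert accepts only a bounded number of tokens per group (its capacity). The following is a minimal NumPy sketch of that routing scheme, not the mesh_tensorflow implementation; the function name top2_route and its exact tie-breaking are illustrative assumptions.

import numpy as np

def top2_route(logits, capacity_factor=1.25, min_expert_capacity=4):
    """Toy top-2 routing with a per-expert capacity limit (illustrative only).

    logits: [num_tokens, num_experts] gating scores for one group.
    Returns a [num_tokens, num_experts] 0/1 dispatch mask.
    """
    num_tokens, num_experts = logits.shape
    # Per-expert capacity, mirroring the capacity_factor / min_expert_capacity knobs above.
    capacity = max(min_expert_capacity,
                   int(capacity_factor * num_tokens / num_experts))
    # Best and second-best expert index for every token.
    order = np.argsort(-logits, axis=1)
    top1, top2 = order[:, 0], order[:, 1]

    dispatch = np.zeros((num_tokens, num_experts), dtype=np.int32)
    load = np.zeros(num_experts, dtype=np.int32)
    for choice in (top1, top2):              # first-place choices, then second-place
        for token, expert in enumerate(choice):
            if load[expert] < capacity and dispatch[token, expert] == 0:
                dispatch[token, expert] = 1  # token fits under the capacity; route it
                load[expert] += 1            # otherwise it overflows and is dropped
    return dispatch

# Example: 8 tokens, 4 experts -> capacity = max(4, int(1.25 * 8 / 4)) = 4
rng = np.random.default_rng(0)
mask = top2_route(rng.standard_normal((8, 4)))
print(mask.sum(axis=0))  # tokens assigned per expert, each <= capacity

Tokens that overflow an expert's capacity are simply dropped in this sketch; the real layer also weights each expert's contribution by its gate value and adds a load-balancing loss scaled by moe_loss_coef, which is what gets appended to self._extra_losses above.
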
Code Example #2
 def __init__(self,
              num_experts=16,
              loss_coef=1e-2,
              hidden_size=4096,
              group_size=1024,
              capacity_factor_train=1.25,
              capacity_factor_eval=2.0,
              use_second_place_loss=False,
              second_policy_train="random",
              second_policy_eval="random",
              second_threshold_train=0.2,
              second_threshold_eval=0.2,
              dropout_rate=0.0,
              activation="relu",
              moe_gating="top_2",
              min_expert_capacity=4,
              switch_policy_train="input_jitter",
              switch_policy_eval="input_jitter",
              switch_dropout=0.1,
              switch_temperature=1.0,
              switch_jitter=1e-2,
              ntlb_top_k=4,
              output_dim=None,
              z_loss=None,
              word_embed_mode=None,
              use_second_place_expert_prob=None,
              use_second_place_expert_prob_temp=None,
              moe_num_layers=1,
              heterogeneous_mask_info=None,
              top_n_num_experts_per_token=3):
     # Collect the constructor arguments into a single moe.HParams object that the
     # underlying MoE layer implementation consumes.
     self._hparams = moe.HParams(
         moe_gating=moe_gating,
         moe_num_experts=num_experts,
         moe_loss_coef=loss_coef,
         moe_hidden_size=hidden_size,
         moe_group_size=group_size,
         moe_min_expert_capacity=min_expert_capacity,
         moe_capacity_factor_train=capacity_factor_train,
         moe_capacity_factor_eval=capacity_factor_eval,
         moe_use_second_place_loss=use_second_place_loss,
         moe_second_policy_train=second_policy_train,
         moe_second_policy_eval=second_policy_eval,
         moe_second_threshold_train=second_threshold_train,
         moe_second_threshold_eval=second_threshold_eval,
         moe_dropout_rate=dropout_rate,
         moe_switch_policy_train=switch_policy_train,
         moe_switch_policy_eval=switch_policy_eval,
         moe_switch_dropout=switch_dropout,
         moe_switch_temperature=switch_temperature,
         moe_switch_jitter=switch_jitter,
         moe_output_dim=output_dim,
         moe_ntlb_top_k=ntlb_top_k,
         moe_z_loss=z_loss,
         moe_word_embed_mode=word_embed_mode,
         moe_use_second_place_expert_prob=(use_second_place_expert_prob),
         moe_use_second_place_expert_prob_temp=(
             use_second_place_expert_prob_temp),
         moe_top_n_num_experts_per_token=top_n_num_experts_per_token,
         moe_num_layers=moe_num_layers,
         moe_heterogeneous_mask_info=heterogeneous_mask_info)
     self._activation = activation
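
The capacity implied by these defaults is easy to work out: each expert in a group of group_size tokens accepts roughly capacity_factor * group_size / num_experts tokens, floored at min_expert_capacity (the exact rounding inside mesh_tensorflow may differ slightly). A quick back-of-the-envelope check with the default values:

# Rough per-expert capacity for the defaults above; mesh_tensorflow's exact
# rounding may differ, this is only a sanity check.
group_size = 1024
num_experts = 16
min_expert_capacity = 4

for mode, factor in [("train", 1.25), ("eval", 2.0)]:
    capacity = max(min_expert_capacity, int(factor * group_size / num_experts))
    print(f"{mode}: {capacity} tokens per expert per group")
# train: 80 tokens per expert per group
# eval: 128 tokens per expert per group
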