Example No. 1
def compute_batch_size(sequence_length,
                       mesh_shape,
                       layout_rules,
                       method_and_value):
  """Compute the total batch size in sequences.

  method_and_value is a (string, int) pair.
  The method string is one of the following four options:

  "sequences_per_batch"
  "tokens_per_batch"
  "sequences_per_replica"
  "tokens_per_replica"

  According to the method string, the value represents either a number of
  sequences or a number of tokens, and applies either to the total batch
  or to the portion of the batch assigned to each model replica.

  For example ("tokens_per_replica", 2048) means that the batch size should be
  set so that the number of tokens per model replica is 2048.  So if the
  sequence length is 1024 and there is 16-way data-parallelism, then the number
  of sequences per batch would be 2048 * 16 / 1024 = 32.

  The "per_batch" versions are useful for ensuring identical overall batch
  sizes across different mesh shapes/layouts.  The "per_replica" versions are
  useful for scaling up the total batch size relative to the degree of
  data-parallelism.

  Args:
    sequence_length: an integer
    mesh_shape: an input to mtf.convert_to_shape()
    layout_rules: an input to mtf.convert_to_layout_rules()
    method_and_value: a pair
  Returns:
    an integer - the number of sequences per batch
  """
  def checkdiv(a, b):
    if a % b:
      raise ValueError("%d is not divisible by %d" % (a, b))
    return a // b
  num_replicas = (
      mtf.tensor_dim_to_mesh_dim_size(
          layout_rules, mesh_shape, mtf.Dimension("batch", 0)) *
      mtf.tensor_dim_to_mesh_dim_size(
          layout_rules, mesh_shape, mtf.Dimension("outer_batch", 0)))
  method, value = method_and_value
  if method == "sequences_per_batch":
    return value
  elif method == "tokens_per_batch":
    return checkdiv(value, sequence_length)
  elif method == "sequences_per_replica":
    return value * num_replicas
  elif method == "tokens_per_replica":
    return checkdiv(value, sequence_length) * num_replicas
  else:
    raise ValueError("unknown method %s" % method)
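
A minimal worked sketch of the arithmetic described in the docstring above, using its hypothetical values (sequence_length=1024, 16-way data-parallelism); plain Python only, no mtf calls:

# Hypothetical values taken from the docstring example above.
sequence_length = 1024
num_replicas = 16  # assumed degree of data-parallelism

# ("tokens_per_replica", 2048): 2048 // 1024 * 16 = 32 sequences per batch
assert (2048 // sequence_length) * num_replicas == 32
# ("sequences_per_replica", 2): 2 * 16 = 32 sequences per batch
assert 2 * num_replicas == 32
# ("tokens_per_batch", 32768): 32768 // 1024 = 32 sequences per batch
assert 32768 // sequence_length == 32
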
Example No. 2
def LogDistribution(tsr, end='\n'):
    if log_distribution:
        print(tsr.shape)
        mesh_impl = mesh_to_impl[tsr.mesh]
        print('(', end='')
        for d in tsr.shape.dims:
            print(mtf.tensor_dim_to_mesh_dim_size(
                mesh_impl.layout_rules, mesh_impl.shape, d),
                  end=', ')
        print(')', end=end, flush=True)
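
A hedged sketch of the module-level context this helper assumes; the names log_distribution and mesh_to_impl are inferred from how they are used above and are not shown in any example:

# Assumed module-level context for the helper above (names are assumptions
# based on usage, not taken from any of these examples).
log_distribution = True   # flag that enables the debug printout
mesh_to_impl = {}         # assumed mapping: mtf.Mesh -> mtf.MeshImpl

# For a tensor whose dimensions are split 4-way and 2-way on the mesh, the
# loop above prints the tensor shape followed by a line like "(4, 2, )".
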
Example No. 3
def auto_batch_size(sequence_length,
                    mesh_shape,
                    layout_rules,
                    tokens_per_split=2048):
    """Automatically compute batch size.

  Args:
    sequence_length: an integer
    mesh_shape: an input to mtf.convert_to_shape()
    layout_rules: an input to mtf.convert_to_layout_rules()
    tokens_per_split: an integer
  Returns:
    an integer
  """
    num_splits = mtf.tensor_dim_to_mesh_dim_size(layout_rules, mesh_shape,
                                                 mtf.Dimension("batch", 0))
    ret = max(1, tokens_per_split // sequence_length) * num_splits
    tf.logging.info("AUTO_BATCH_SIZE tokens_per_split=%s num_splits=%s"
                    " sequence_length=%s batch_size=%s" %
                    (tokens_per_split, num_splits, sequence_length, ret))
    return ret
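
A small worked check of the formula above, with hypothetical values (tokens_per_split=2048, sequence_length=512, an 8-way split of the batch dimension):

tokens_per_split, sequence_length, num_splits = 2048, 512, 8  # hypothetical
batch_size = max(1, tokens_per_split // sequence_length) * num_splits
assert batch_size == 32   # 4 sequences per split * 8 splits
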
Example No. 4
def transformer_moe_layer_v2(inputs,
                             output_dim,
                             hparams,
                             train,
                             variable_dtype,
                             layout=None,
                             mesh_shape=None,
                             nonpadding=None):
    """2-level mixture of experts.

  Adapted from the paper https://arxiv.org/abs/1701.06538

  Note: until the algorithm and interface solidify, we pass in a hyperparameters
  dictionary in order not to complicate the interface in mtf_transformer.py.
  Once this code moves out of "research", we should pass the hyperparameters
  separately.

  Hyperparameters used:
    hparams.moe_num_experts: number of experts
    hparams.moe_hidden_size: size of hidden layer in each expert
    hparams.moe_group_size: size of each "group" for gating purposes
    hparams.moe_capacity_factor_train: a float
    hparams.moe_capacity_factor_eval: a float
    hparams.moe_capacity_factor_second_level: a float
    hparams.moe_gating: a string
    + all hyperparameters used by _top_2_gating()

  There is one set of gating parameters for the first level, and a different
  set of gating parameters per first-level expert for the second level.
  The number of parameters in the gating network is:
    (input_dim.size * hparams.moe_num_experts[0]) +
    (input_dim.size * hparams.moe_num_experts[0] * hparams.moe_num_experts[1])


  The number of parameters in the experts themselves is:
    (hparams.moe_num_experts[0] * hparams.moe_num_experts[1]
     * (input_dim.size + output_dim.size)
     * hparams.moe_hidden_size)

  The input is n-dimensional: [<batch_and_length_dims>, input_dim], consisting
  of the representations of all positions in a batch of sequences.

  Each position of each sequence is sent to 0-3 experts.  The expert
  choices and the combination weights are determined by a learned gating
  function.

  This function returns a small auxiliary loss that should be added to the
  training loss of the model.  This loss helps to balance expert usage.
  Without the loss, it is very likely that a few experts will be trained and
  the rest will starve.

  Several hacks are necessary to get around current TPU limitations:

  - To ensure static shapes, we enforce (by truncation/padding)
    that each sequence sends the same number of elements to each expert.

    It would make more sense to enforce this equality over the entire batch,
    but due to our hacked-up gather-by-matmul implementation, we need to divide
    the batch into "groups".  For each group, the same number of elements
    are sent to each expert.

  TODO(noam): Factor this code better.  We want to be able to substitute
  different code for the experts themselves.

  Dimensions cheat sheet:
  a, b: batch size
  l: original sequence length
  m: input depth
  n: output depth
  g, h: number of groups
  s, t: group size
  x, y: number of experts
  c, d: expert capacity

  input: [a0, b1, l, m]
  input: [a0, g1, s, m]
  dispatch_tensor_x: [a0, g1, s, x, c]
  expert_input: [a0, g1, x, c, m]
  alltoall: [a0, g, x1, c, m]
  alltoall: [a0, g, x1, c, m]
  transpose: [x1, a0, g, c, m]
  reshape: [x1, h0, s, m]
  assignment2: [x1, h0, t, y, d]
  expert_input2: [x1, h0, y, d, m]
  alltoall: [x1, h, y0, d, m]
  ...
  reverse of that

  gating params 0: [m, x]
  gating params 1: [x1, m, y]

  expert params:
     [x1, y0, m, hidden]
     [x1, y0, hidden, n]

  Args:
    inputs: a mtf.Tensor with shape [a, b, l, m]
    output_dim: a mtf.Dimension (for Transformer, this is input_dim)
    hparams: model hyperparameters
    train: a boolean
    variable_dtype: a mtf.VariableDType
    layout: optional - an input to mtf.convert_to_layout_rules
    mesh_shape: optional - an input to mtf.convert_to_shape
    nonpadding: an optional mtf.Tensor with shape [a, b, l]
      and the same dtype as inputs, consisting of ones(nonpadding)
      and zeros(padding).

  Returns:
    outputs: a Tensor with shape [a, b, l, n]
    loss: a mtf scalar

  Raises:
    ValueError: on unrecognized hparams.moe_gating
  """
    if nonpadding is not None:
        nonpadding = mtf.zeros(inputs.mesh,
                               inputs.shape.dims[:-1],
                               dtype=inputs.dtype) + nonpadding
    insert_outer_batch_dim = (len(inputs.shape.dims) == 3)
    if insert_outer_batch_dim:
        inputs = mtf.reshape(inputs, [mtf.Dimension("outer_batch", 1)] +
                             inputs.shape.dims)

    assert len(hparams.moe_num_experts) == 2
    a0, b1, l, m = inputs.shape.dims
    hidden_dim = mtf.Dimension("expert_hidden", hparams.moe_hidden_size)
    x1 = mtf.Dimension("expert_x", hparams.moe_num_experts[0])
    y0 = mtf.Dimension("expert_y", hparams.moe_num_experts[1])
    x = mtf.Dimension("expert_x_unsplit", hparams.moe_num_experts[0])
    y = mtf.Dimension("expert_y_unsplit", hparams.moe_num_experts[1])
    n = output_dim

    # We "cheat" here and look at the mesh shape and layout. This is to ensure
    # that the number of groups (g.size) is a multiple of the mesh dimension
    # over which those groups are split.
    num_groups, group_size = _split_into_groups(
        b1.size * l.size, hparams.moe_group_size,
        mtf.tensor_dim_to_mesh_dim_size(layout, mesh_shape, b1))
    g1 = mtf.Dimension(b1.name, num_groups)
    g = mtf.Dimension(b1.name + "_unsplit", g1.size)
    s = mtf.Dimension("group_size_x", group_size)

    # Each sequence sends at most expert_capacity positions to each expert.
    # A static expert_capacity dimension is needed for expert batch sizes.
    if train:
        capacity_factor = hparams.moe_capacity_factor_train
    else:
        capacity_factor = hparams.moe_capacity_factor_eval
    expert_capacity = min(s.size, int((s.size * capacity_factor) / x.size))
    expert_capacity = max(expert_capacity, 4)
    c = mtf.Dimension("expert_capacity_x", expert_capacity)

    # We "cheat" here and look at the mesh shape and layout. This is to ensure
    # that the number of groups (h.size) is a multiple of the mesh dimension
    # over which those groups are split.
    num_groups, group_size = _split_into_groups(
        a0.size * g.size * c.size, hparams.moe_group_size,
        mtf.tensor_dim_to_mesh_dim_size(layout, mesh_shape, a0))
    t = mtf.Dimension("group_size_y", group_size)
    h0 = mtf.Dimension(a0.name, num_groups)
    h = mtf.Dimension(a0.name + "_unsplit", h0.size)

    expert_capacity = min(
        t.size,
        int((t.size * hparams.moe_capacity_factor_second_level) / y.size))
    expert_capacity = max(expert_capacity, 4)
    d = mtf.Dimension("expert_capacity_y", expert_capacity)

    # First level of expert routing
    # Reshape the inner batch size to a multiple of group_dim g1 and
    # group_size_dim s.
    inputs = mtf.reshape(inputs, [a0, g1, s, m])
    if nonpadding is not None:
        nonpadding = mtf.reshape(nonpadding, [a0, g1, s])

    # Get the assignments for the first level.
    # dispatch_tensor_x has shape [a0, g1, s, x, c]
    if hparams.moe_gating == "top_2":
        dispatch_tensor_x, combine_tensor_x, loss_outer = _top_2_gating(
            inputs=inputs,
            outer_expert_dims=None,
            experts_dim=x,
            expert_capacity_dim=c,
            hparams=hparams,
            train=train,
            variable_dtype=variable_dtype,
            name="outer_gating",
            importance=nonpadding)
    else:
        raise ValueError("unknown hparams.moe_gating=%s" % hparams.moe_gating)

    # Now create expert_inputs based on the assignments.
    # put num_experts dimension first to make split easier in alltoall
    expert_inputs_x = mtf.einsum([inputs, dispatch_tensor_x],
                                 [x, a0, g1, c, m])

    # we construct an "importance" Tensor for the inputs to the second-level
    # gating.  The importance of an input is 1.0 if it represents the
    # first-choice expert-group and 0.5 if it represents the second-choice expert
    # group.  This is used by the second-level gating.
    importance = mtf.reduce_sum(combine_tensor_x, output_shape=[x, a0, g1, c])
    importance = 0.5 * (mtf.to_float(mtf.greater(importance, 0.5)) +
                        mtf.to_float(mtf.greater(importance, 0.0)))

    # First level, all to all. Here we change the split dimension from g1 to x1.
    expert_inputs_x = mtf.reshape(expert_inputs_x, mtf.Shape([x1, a0, g, c,
                                                              m]))
    importance = mtf.reshape(importance, [x1, a0, g, c])

    # Second level of expert routing
    # Reshape the expert_inputs outer batch dim to be a multiple of group_dim h0
    # and group_size_dim t.
    inputs_y = mtf.reshape(expert_inputs_x, [x1, h0, t, m])
    importance = mtf.reshape(importance, [x1, h0, t])

    # Get the assignments for the second level.
    # dispatch_tensor_y has shape [x1, h0, t, y, d]
    if hparams.moe_gating == "top_2":
        dispatch_tensor_y, combine_tensor_y, loss_inner = _top_2_gating(
            inputs=inputs_y,
            outer_expert_dims=[x1],
            experts_dim=y,
            expert_capacity_dim=d,
            hparams=hparams,
            train=train,
            variable_dtype=variable_dtype,
            importance=importance,
            name="inner_gating")
    else:
        raise ValueError("unknown hparams.moe_gating=%s" % hparams.moe_gating)

    # Now create expert_inputs based on the assignments.
    # put num_experts dimension first to make split easier in alltoall
    expert_inputs_y = mtf.einsum([inputs_y, dispatch_tensor_y],
                                 [y, x1, h0, d, m])

    # Second level, all to all. Here we change the split dimension from h0 to y0.
    expert_inputs_y = mtf.reshape(expert_inputs_y, mtf.Shape([y0, x1, h, d,
                                                              m]))

    hidden_output = mtf.layers.dense(expert_inputs_y,
                                     hidden_dim,
                                     expert_dims=[y0, x1],
                                     activation=mtf.relu,
                                     use_bias=False,
                                     variable_dtype=variable_dtype,
                                     name="wi")
    expert_output = mtf.layers.dense(hidden_output,
                                     output_dim,
                                     expert_dims=[y0, x1],
                                     use_bias=False,
                                     variable_dtype=variable_dtype,
                                     name="wo")

    # NOW COMBINE EXPERT OUTPUTS (reversing everything we have done)
    # expert_output has shape [y0, x1, h, d, n]

    # alltoall
    expert_output = mtf.reshape(expert_output, mtf.Shape([y, x1, h0, d, n]))

    # combine results from inner level
    output_y = mtf.einsum([expert_output, combine_tensor_y], [x1, h0, t, n])

    # Reshape the combined tensor from inner level to now contain outer_batch_dim
    # a0 and group_dim g
    output = mtf.reshape(output_y, [x1, a0, g, c, n])

    # alltoall from expert_dim x to group_dim g1
    expert_output_x = mtf.reshape(output, mtf.Shape([x, a0, g1, c, n]))

    # combine results from outer level
    output_x = mtf.einsum([expert_output_x, combine_tensor_x], [a0, g1, s, n])

    # Reshape the combined tensor to now contain inner_batch_dim
    # b1 and the original sequence length
    output = mtf.reshape(output_x, [a0, b1, l, n])
    if insert_outer_batch_dim:
        output = mtf.reshape(output, [b1, l, n])
    return output, (loss_outer + loss_inner) * hparams.moe_loss_coef
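
_split_into_groups is not shown in these examples. The following is a minimal sketch of the contract it appears to satisfy, judging only from how it is called here (num_groups * group_size == n, group_size <= max_group_size, num_groups a multiple of the mesh dimension size); this is an assumption, not the library's implementation:

def _split_into_groups_sketch(n, max_group_size, mesh_dim_size):
    # Illustrative only: pick the smallest num_groups that is a multiple of
    # mesh_dim_size, divides n evenly, and keeps group_size <= max_group_size.
    if n % mesh_dim_size:
        raise ValueError("n must be divisible by mesh_dim_size")
    num_groups = mesh_dim_size
    while n % num_groups or n // num_groups > max_group_size:
        num_groups += mesh_dim_size
    return num_groups, n // num_groups

# e.g. 4096 sequences of length 256, group-size cap 1024, 16-way batch split:
assert _split_into_groups_sketch(4096 * 256, 1024, 16) == (1024, 1024)
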
Example No. 5
def transformer_moe_layer_v1(inputs,
                             output_dim,
                             hparams,
                             train,
                             variable_dtype,
                             layout=None,
                             mesh_shape=None,
                             nonpadding=None):
    """Local mixture of experts that works well on TPU.

  Adapted from the paper https://arxiv.org/abs/1701.06538

  Note: until the algorithm and interface solidify, we pass in a hyperparameters
  dictionary in order not to complicate the interface in mtf_transformer.py.
  Once this code moves out of "research", we should pass the hyperparameters
  separately.

  Hyperparameters used:
    hparams.moe_num_experts: number of experts
    hparams.moe_hidden_size: size of hidden layer in each expert
    hparams.moe_group_size: size of each "group" for gating purposes
    hparams.moe_capacity_factor_train: a float
    hparams.moe_capacity_factor_eval: a float
    hparams.moe_gating: a string
    + all hyperparameters used by _top_2_gating()

  The number of parameters in the gating network is:
    (input_dim.size * hparams.moe_num_experts)

  The number of parameters in the experts themselves is:
    (hparams.moe_num_experts
     * (input_dim.size + output_dim.size)
     * hparams.moe_hidden_size)

  The input is n-dimensional: [<batch_and_length_dims>, input_dim], consisting
  of the representations of all positions in a batch of sequences.

  Each position of each sequence is sent to 0-2 experts.  The expert
  choices and the combination weights are determined by a learned gating
  function.

  This function returns a small auxiliary loss that should be added to the
  training loss of the model.  This loss helps to balance expert usage.
  Without the loss, it is very likely that a few experts will be trained and
  the rest will starve.

  Several hacks are necessary to get around current TPU limitations:

  - To ensure static shapes, we enforce (by truncation/padding)
    that each sequence sends the same number of elements to each expert.

    It would make more sense to enforce this equality over the entire batch,
    but due to our hacked-up gather-by-matmul implementation, we need to divide
    the batch into "groups".  For each group, the same number of elements
    are sent to each expert.

  TODO(noam): Factor this code better.  We want to be able to substitute
  different code for the experts themselves.

  Dimensions cheat sheet:
  <B>: batch dims
  L: original sequence length
  M: input depth
  N: output depth
  G: number of groups
  S: group size
  E: number of experts
  C: expert capacity
  (u for unsplit dims)

  Args:
    inputs: a mtf.Tensor with shape [<batch_dims...>, length_dim, input_dim]
    output_dim: a mtf.Dimension (for Transformer, this is input_dim)
    hparams: model hyperparameters
    train: a boolean
    variable_dtype: a mtf.VariableDType
    layout: optional - an input to mtf.convert_to_layout_rules
    mesh_shape: optional - an input to mtf.convert_to_shape
    nonpadding: an optional Tensor with shape [<batch_dims>, length_dim]
      and the same dtype as inputs, consisting of ones(nonpadding)
      and zeros(padding).

  Returns:
    outputs: a Tensor with shape [<batch_dims...>, length_dim, output_dim]
    loss: a mtf scalar

  Raises:
    ValueError: on unrecognized hparams.moe_gating
  """
    # See "Dimensions cheat sheet"
    # <B>LM Tensor
    orig_inputs = inputs
    hidden_dim = mtf.Dimension("expert_hidden", hparams.moe_hidden_size)
    experts_dim = mtf.Dimension("experts", hparams.moe_num_experts)

    # We "cheat" here and look at the mesh shape and layout. This is to ensure
    # that the number of groups is a multiple of the mesh dimension
    # over which those groups are split.
    batch_and_length_dims, input_dim = (orig_inputs.shape.dims[:-1],
                                        orig_inputs.shape.dims[-1])
    # Hack: we assume that
    #   "outer_batch" == replication of experts
    #   mesh_dim_size can be derived from mesh_shape and orig_batch_dim
    #
    # We then require num_groups to be a multiple of mesh_dim_size.
    if orig_inputs.shape.dims[0].name == "outer_batch":
        outer_batch_dim, orig_batch_dim = orig_inputs.shape.dims[:2]
    else:
        outer_batch_dim, orig_batch_dim = (mtf.Dimension("outer_batch", 1),
                                           orig_inputs.shape.dims[0])

    # Number of MoE inputs (total number of positions across
    # batch_and_length_dims, per replica).
    n = 1
    for d in batch_and_length_dims:
        n *= d.size

    n = n // outer_batch_dim.size

    mesh_dim_size = mtf.tensor_dim_to_mesh_dim_size(layout, mesh_shape,
                                                    orig_batch_dim)
    num_groups, group_size = _split_into_groups(n, hparams.moe_group_size,
                                                mesh_dim_size)

    group_size_dim = mtf.Dimension("group", group_size)
    num_groups_dim = mtf.Dimension(orig_batch_dim.name, num_groups)

    moe_input_dims = [
        outer_batch_dim, num_groups_dim, group_size_dim, input_dim
    ]
    # OGSM Tensor
    inputs = mtf.reshape(inputs, moe_input_dims)

    # Each sequence sends expert_capacity positions to each expert.
    if train:
        capacity_factor = hparams.moe_capacity_factor_train
    else:
        capacity_factor = hparams.moe_capacity_factor_eval
    expert_capacity = min(
        group_size_dim.size,
        int((group_size_dim.size * capacity_factor) / experts_dim.size))
    expert_capacity_dim = mtf.Dimension("expert_capacity", expert_capacity)

    experts_dim_unsplit = mtf.Dimension("expert_unsplit", experts_dim.size)
    batch_dim_unsplit = mtf.Dimension("batch_unsplit", num_groups_dim.size)
    if nonpadding is not None:
        nonpadding = mtf.zeros(inputs.mesh,
                               batch_and_length_dims,
                               dtype=inputs.dtype) + nonpadding
        nonpadding = mtf.reshape(nonpadding, moe_input_dims[:-1])
    if hparams.moe_gating == "top_2":
        # dispatch_tensor and combine_tensor are
        # <B>GSEC Tensors
        dispatch_tensor, combine_tensor, loss = _top_2_gating(
            inputs=inputs,
            outer_expert_dims=None,
            experts_dim=experts_dim_unsplit,
            expert_capacity_dim=expert_capacity_dim,
            hparams=hparams,
            train=train,
            variable_dtype=variable_dtype,
            importance=nonpadding)
    else:
        raise ValueError("unknown hparams.moe_gating=%s" % hparams.moe_gating)

    expert_inputs = mtf.einsum([inputs, dispatch_tensor],
                               mtf.Shape([
                                   outer_batch_dim, experts_dim_unsplit,
                                   num_groups_dim, expert_capacity_dim,
                                   input_dim
                               ]))

    expert_inputs = mtf.reshape(
        expert_inputs,
        mtf.Shape([
            outer_batch_dim, experts_dim, batch_dim_unsplit,
            expert_capacity_dim, input_dim
        ]))

    # Now feed the expert inputs through the experts.
    h = mtf.layers.dense(expert_inputs,
                         hidden_dim,
                         expert_dims=[experts_dim],
                         activation=mtf.relu,
                         use_bias=False,
                         variable_dtype=variable_dtype,
                         name="wi")

    expert_output = mtf.layers.dense(h,
                                     output_dim,
                                     expert_dims=[experts_dim],
                                     use_bias=False,
                                     variable_dtype=variable_dtype,
                                     name="wo")

    expert_output = mtf.reshape(
        expert_output,
        mtf.Shape([
            outer_batch_dim,
            experts_dim_unsplit,
            num_groups_dim,
            expert_capacity_dim,
            output_dim,
        ]))

    moe_output_dims = moe_input_dims[:-1] + [output_dim]
    output = mtf.einsum([expert_output, combine_tensor],
                        mtf.Shape(moe_output_dims))
    output = mtf.reshape(output, batch_and_length_dims + [output_dim])

    return output, loss * hparams.moe_loss_coef
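
A worked check of the parameter counts given in the docstring above, with hypothetical sizes (input_dim = output_dim = 1024, 64 experts, hidden size 8192):

input_dim, output_dim = 1024, 1024            # hypothetical model width
num_experts, moe_hidden_size = 64, 8192       # hypothetical MoE settings

gating_params = input_dim * num_experts
expert_params = num_experts * (input_dim + output_dim) * moe_hidden_size
assert gating_params == 65536
assert expert_params == 64 * 2048 * 8192      # ~1.07B parameters
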
Example No. 6
def maybe_reshape_attention_input_for_2d_sharding(
    context, q, k, v, bias, unsplittable_dims):
  """Reshape the inputs to attention to split over an unused mesh dimension.

  In the case where the attention computation is unnecessarily replicated,
  this function reshapes the attention inputs to remove the unnecessary
  replication.

  This becomes relevant when doing 2-dimensional model parallelism.
  d_model is sharded over one mesh dimension and [vocab, num_heads, d_ff] are
  sharded over the other mesh dimension.  This fully distributes all of the
  einsum operations, except for the internals of the attention computation.

  To distribute that computation, this function creates a new tensor-dimension
  from the low bits of either the batch dimension or the num_heads dimension,
  and then splits that dimension over the unused mesh dimension.

  Args:
    context: a transformer.Context
    q: a Tensor
    k: a Tensor
    v: a Tensor
    bias: a Tensor
    unsplittable_dims: a list of tensor-dimensions not to split.  The key/value
      dimensions should be passed here.
  Returns:
    reshaped_q: a Tensor
    reshaped_k: a Tensor
    reshaped_v: a Tensor
    reshaped_bias: a Tensor
  """
  original_inputs = q, k, v, bias
  # we need to know the layout and mesh-shape to figure out what to do.
  if not context or not context.model.layout or not context.model.mesh_shape:
    return original_inputs
  mesh_shape = mtf.convert_to_shape(context.model.mesh_shape)
  layout_rules = mtf.convert_to_layout_rules(context.model.layout)
  # find a mesh dim that is unused (no tensor-dimension is split across it)
  mesh_axis_used = [False] * mesh_shape.ndims
  for x in original_inputs:
    for mesh_axis in layout_rules.tensor_layout(
        x.shape, mesh_shape).tensor_axis_to_mesh_axis:
      if mesh_axis is not None:
        mesh_axis_used[mesh_axis] = True
  if False not in mesh_axis_used:
    return original_inputs
  mesh_dim = mesh_shape.dims[mesh_axis_used.index(False)]
  # Choose an appropriate name for the new tensor-dimension so that the layout
  #   will know to split it across the unused mesh dimension.
  tensor_dim_name = layout_rules.mesh_dimension_name_to_tensor_dimension_names(
      mesh_dim.name)
  if tensor_dim_name:
    tensor_dim_name = tensor_dim_name[0]
  else:
    return original_inputs
  # Find a tensor-dimension that we can further split, by breaking off the
  # lower bits into our new tensor-dimension.
  # This resplittable tensor-dimension must be present in all of q, k, v
  #   and must be large enough to be further split.
  resplittable_dim = None
  for d in q.shape.dims:
    if d in k.shape.dims and d in v.shape.dims and d not in unsplittable_dims:
      num_splits = mtf.tensor_dim_to_mesh_dim_size(
          context.model.layout, context.model.mesh_shape, d)
      if d.size % (num_splits * mesh_dim.size) == 0:
        resplittable_dim = d
        break
  if not resplittable_dim:
    return original_inputs
  new_dim_high = mtf.Dimension(resplittable_dim.name, num_splits)
  new_dim_low = mtf.Dimension(tensor_dim_name,
                              resplittable_dim.size // num_splits)
  def _my_reshape(x):
    if x and resplittable_dim in x.shape.dims:
      return mtf.replace_dimensions(
          x, resplittable_dim, [new_dim_high, new_dim_low])
    else:
      return x
  return _my_reshape(q), _my_reshape(k), _my_reshape(v), _my_reshape(bias)
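
A small worked check of the re-splitting arithmetic above, with hypothetical sizes: a resplittable dimension of size 16 that is already split 4 ways, and an unused mesh dimension of size 2:

dim_size, num_splits, unused_mesh_dim_size = 16, 4, 2   # hypothetical
# The dimension qualifies only if it can absorb both splits.
assert dim_size % (num_splits * unused_mesh_dim_size) == 0
# new_dim_high keeps the original name and gets size num_splits;
# new_dim_low takes the new name and the remaining size.
new_dim_high_size = num_splits
new_dim_low_size = dim_size // num_splits
assert new_dim_high_size * new_dim_low_size == dim_size  # total size preserved
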
Example No. 7
def transformer_moe_layer_v1(inputs,
                             output_dim,
                             hparams,
                             train,
                             variable_dtype,
                             layout=None,
                             mesh_shape=None,
                             nonpadding=None,
                             activation=mtf.relu,
                             num_microbatches=None,
                             token_embeddings=None,
                             context=None):
    """Local heterogeneous mixture of experts.

  See transformer_moe_layer_v1 in moe.py for a more detailed explanation for
  a generic moe layer.

  The heterogeneous mask produced by generate_heterogeneous_expert_masks has
  dimension [maximum hidden size, maximum # layers, # experts] and its shape
  will overwrite the parameters moe_num_layers and moe_hidden_size in hparams.
  The layer-specific mask slice is applied at each expert layer to the
  activation which is [expert width, # experts]. If the heterogeneous_mask_info
  is None, there is no mask applied and the code is equivalent to the
  homogeneous case.


  The input is n-dimensional: [<batch_and_length_dims>, input_dim], consisting
  of the representations of all positions in a batch of sequences.

  Each position of each sequence is sent to 0-2 experts.  The expert
  choices and the combination weights are determined by a learned gating
  function.

  This function returns a small auxiliary loss that should be added to the
  training loss of the model.  This loss helps to balance expert usage.
  Without the loss, it is very likely that a few experts will be trained and
  the rest will starve.

  Dimensions cheat sheet:
  B: batch dim(s)
  L: original sequence length
  M: input depth
  N: output depth
  G: number of groups
  S: group size
  E: number of experts
  C: expert capacity

  Args:
    inputs: a mtf.Tensor with shape [batch_dim(s), length_dim, input_dim]
    output_dim: a mtf.Dimension (for Transformer, this is input_dim)
    hparams: model hyperparameters
    train: a boolean
    variable_dtype: a mtf.VariableDType
    layout: optional - an input to mtf.convert_to_layout_rules
    mesh_shape: optional - an input to mtf.convert_to_shape
    nonpadding: an optional Tensor with shape [batch_dim(s), length_dim]
      and the same dtype as inputs, consisting of ones(nonpadding)
      and zeros(padding).
    activation: a function.
    num_microbatches: number of microbatches.
    token_embeddings: a mtf.Tensor with shape
      [batch_dim(s), length_dim, input_dim]. These are the word embeddings
      that correspond to the inputs. These can optionally be used to make
      routing decisions.
    context: a Context.

  Returns:
    outputs: a Tensor with shape [batch_dim(s), length_dim, output_dim]
    loss: a mtf scalar

  Raises:
    ValueError: on unrecognized hparams.moe_gating
  """
    orig_inputs = inputs

    experts_dim = mtf.Dimension("experts", hparams.moe_num_experts)

    if hparams.moe_heterogeneous_mask_info is not None:
        tf.logging.info("moe_heterogeneous_mask_info: {}".format(
            hparams.moe_heterogeneous_mask_info))
        heterogeneous_mask = generate_heterogeneous_expert_masks(
            hparams.moe_heterogeneous_mask_info,
            hparams.moe_num_experts,
            experts_dim,
            mesh=inputs.mesh,
            expert_width=hparams.moe_hidden_size)
        # overwrite depth and width with the mask maximum dimension
        hparams.moe_num_layers = heterogeneous_mask.shape[1].size
        hparams.moe_hidden_size = heterogeneous_mask.shape[0].size
    hidden_dim = mtf.Dimension("expert_hidden", hparams.moe_hidden_size)

    # We "cheat" here and look at the mesh shape and layout. This is to ensure
    # that the number of groups is a multiple of the mesh dimension
    # over which those groups are split.
    batch_and_length_dims, input_dim = (orig_inputs.shape.dims[:-1],
                                        orig_inputs.shape.dims[-1])
    # Hack: we assume that
    #   "outer_batch" == replication of experts
    #   mesh_dim_size can be derived from mesh_shape and orig_batch_dim
    #
    # We then require num_groups to be a multiple of mesh_dim_size.
    if orig_inputs.shape.dims[0].name == "outer_batch":
        outer_batch_dim, orig_batch_dim = orig_inputs.shape.dims[:2]
    else:
        outer_batch_dim, orig_batch_dim = (mtf.Dimension("outer_batch", 1),
                                           orig_inputs.shape.dims[0])

    # Number of MoE inputs (total number of positions across
    # batch_and_length_dims, per replica).
    n = 1
    for d in batch_and_length_dims:
        n *= d.size

    n = n // outer_batch_dim.size

    mesh_dim_size = mtf.tensor_dim_to_mesh_dim_size(layout, mesh_shape,
                                                    orig_batch_dim)
    num_groups, group_size = moe._split_into_groups(  # pylint: disable=protected-access
        n, hparams.moe_group_size, mesh_dim_size)
    # TODO(barretzoph): implementation without pylint calls?

    group_size_dim = mtf.Dimension("group", group_size)
    num_groups_dim = mtf.Dimension(orig_batch_dim.name, num_groups)

    moe_input_dims = [
        outer_batch_dim, num_groups_dim, group_size_dim, input_dim
    ]
    # OGSM Tensor
    inputs = mtf.reshape(inputs, moe_input_dims)

    # Token embeddings that can be optionally used in the router for determining
    # where to send tokens.
    if hparams.moe_word_embed_mode is not None:
        token_embeddings = mtf.cast(
            mtf.reshape(token_embeddings, moe_input_dims), inputs.dtype)

    # Each sequence sends expert_capacity positions to each expert.
    if train:
        capacity_factor = hparams.moe_capacity_factor_train
    else:
        capacity_factor = hparams.moe_capacity_factor_eval
    expert_capacity = min(
        group_size_dim.size,
        int((group_size_dim.size * capacity_factor) / experts_dim.size))
    expert_capacity = max(expert_capacity, hparams.moe_min_expert_capacity)
    tf.logging.info("expert_capacity: %d" % expert_capacity)
    expert_capacity_dim = mtf.Dimension("expert_capacity", expert_capacity)
    experts_dim_unsplit = mtf.Dimension("expert_unsplit", experts_dim.size)
    batch_dim_unsplit = mtf.Dimension("batch_unsplit", num_groups_dim.size)
    if nonpadding is not None:
        nonpadding = mtf.zeros(inputs.mesh,
                               batch_and_length_dims,
                               dtype=inputs.dtype) + nonpadding
        nonpadding = mtf.reshape(nonpadding, moe_input_dims[:-1])
    if hparams.moe_gating == "top_2":
        # combine_tensor,
        # dispatch_tensor  OG`SEC Tensors
        # (G is generally split along mesh dim)
        dispatch_tensor, combine_tensor, loss = moe._top_2_gating(  # pylint: disable=protected-access
            inputs=inputs,
            outer_expert_dims=None,
            experts_dim=experts_dim_unsplit,
            expert_capacity_dim=expert_capacity_dim,
            hparams=hparams,
            train=train,
            variable_dtype=variable_dtype,
            importance=nonpadding,
            num_microbatches=num_microbatches,
            token_embeddings=token_embeddings)
    elif hparams.moe_gating == "top_n":
        dispatch_tensor, combine_tensor, loss = moe._top_n_gating(  # pylint: disable=protected-access
            inputs=inputs,
            outer_expert_dims=None,
            experts_dim=experts_dim_unsplit,
            expert_capacity_dim=expert_capacity_dim,
            hparams=hparams,
            train=train,
            variable_dtype=variable_dtype,
            importance=nonpadding,
            num_microbatches=num_microbatches,
            token_embeddings=token_embeddings)
    elif hparams.moe_gating == "switch":
        dispatch_tensor, combine_tensor, loss = moe._switch_gating(  # pylint: disable=protected-access
            inputs=inputs,
            outer_expert_dims=None,
            experts_dim=experts_dim_unsplit,
            expert_capacity_dim=expert_capacity_dim,
            hparams=hparams,
            train=train,
            variable_dtype=variable_dtype,
            importance=nonpadding,
            num_microbatches=num_microbatches,
            token_embeddings=token_embeddings)
    elif hparams.moe_gating == "ntlb":
        dispatch_tensor, combine_tensor, loss = moe._ntlb_gating(  # pylint: disable=protected-access
            inputs=inputs,
            outer_expert_dims=None,
            experts_dim=experts_dim_unsplit,
            expert_capacity_dim=expert_capacity_dim,
            hparams=hparams,
            train=train,
            variable_dtype=variable_dtype,
            importance=nonpadding,
            num_microbatches=num_microbatches,
            token_embeddings=token_embeddings)
    elif hparams.moe_gating == "switch_max":
        dispatch_tensor, combine_tensor, loss = moe._switch_max_gating(  # pylint: disable=protected-access
            inputs=inputs,
            outer_expert_dims=None,
            experts_dim=experts_dim_unsplit,
            expert_capacity_dim=expert_capacity_dim,
            hparams=hparams,
            train=train,
            variable_dtype=variable_dtype,
            importance=nonpadding,
            num_microbatches=num_microbatches,
            token_embeddings=token_embeddings)
    elif hparams.moe_gating == "expert_selection":
        dispatch_tensor, combine_tensor, loss = moe._expert_selection_gating(  # pylint: disable=protected-access
            inputs=inputs,
            outer_expert_dims=None,
            experts_dim=experts_dim_unsplit,
            group_size_dim=group_size_dim,
            expert_capacity_dim=expert_capacity_dim,
            hparams=hparams,
            train=train,
            variable_dtype=variable_dtype,
            importance=nonpadding,
            name="expert_selection_gating",
            num_microbatches=num_microbatches,
            token_embeddings=token_embeddings)
    else:
        raise ValueError("unknown hparams.moe_gating=%s" % hparams.moe_gating)

    expert_inputs = mtf.einsum([inputs, dispatch_tensor],
                               mtf.Shape([
                                   outer_batch_dim, experts_dim_unsplit,
                                   num_groups_dim, expert_capacity_dim,
                                   input_dim
                               ]))

    # Extra reshape reduces communication cost for model-parallel versions.
    # For model-parallel versions, this reshape causes an mtf.slice and for non-
    # model-parallel versions, this has no effect.
    d_model_split_dim = mtf.Dimension("d_model_split", input_dim.size)
    expert_inputs = mtf.reshape(
        expert_inputs,
        mtf.Shape([
            outer_batch_dim, experts_dim, batch_dim_unsplit,
            expert_capacity_dim, d_model_split_dim
        ]))

    # Split over batch -> split over experts
    expert_inputs = mtf.reshape(
        expert_inputs,
        mtf.Shape([
            outer_batch_dim, experts_dim, batch_dim_unsplit,
            expert_capacity_dim, input_dim
        ]))

    # Pretend we have heterogeneous_mask with shape [moe_num_layers, num_experts]
    for layer_idx in range(hparams.moe_num_layers):
        with tf.variable_scope("expert_layer_{}".format(layer_idx)):
            res_h = 0.0
            if layer_idx > 0:
                res_h = expert_inputs
                expert_inputs = transformer.sublayer_rms_norm(
                    expert_inputs, None, context)

            # Now feed the expert inputs through the experts.
            h = mtf.layers.dense_product(
                expert_inputs,
                reduced_dims=expert_inputs.shape.dims[-1:],
                new_dims=[hidden_dim],
                expert_dims=[experts_dim],
                activation_functions=activation,
                use_bias=False,
                variable_dtype=variable_dtype,
                name="wi")

            # apply dropout
            if hparams.moe_dropout_rate != 0.0:
                h = mtf.dropout(h,
                                is_training=train,
                                keep_prob=1.0 - hparams.moe_dropout_rate)
            # only if heterogeneous
            if hparams.moe_heterogeneous_mask_info is not None:
                # Get mask for current layer by slicing heterogeneous mask
                heterogeneous_mask_slice = mtf.slice(heterogeneous_mask,
                                                     layer_idx, 1,
                                                     "num_expert_layers")

                # Get rid of the expert layers dimension.
                heterogeneous_mask_slice = mtf.reshape(
                    heterogeneous_mask_slice, [
                        heterogeneous_mask_slice.shape[0],
                        heterogeneous_mask_slice.shape[-1]
                    ])
                h *= mtf.cast(heterogeneous_mask_slice, h.dtype)
            expert_output = mtf.layers.dense(h,
                                             output_dim,
                                             expert_dims=[experts_dim],
                                             use_bias=False,
                                             reduced_dims=h.shape.dims[-1:],
                                             variable_dtype=variable_dtype,
                                             name="wo")

            if layer_idx < (hparams.moe_num_layers - 1):
                expert_output = transformer.sublayer_dropout(
                    expert_output, None, context)
            expert_output += res_h
            expert_inputs = expert_output

    # Extra reshape reduces communication cost for model-parallel versions.
    # For model-parallel versions, this reshape causes an mtf.slice and for non-
    # model-parallel versions, this has no effect.
    expert_output = mtf.reshape(
        expert_output,
        mtf.Shape([
            outer_batch_dim, experts_dim_unsplit, num_groups_dim,
            expert_capacity_dim, d_model_split_dim
        ]))

    # Split over experts -> split over batch
    expert_output = mtf.reshape(
        expert_output,
        mtf.Shape([
            outer_batch_dim,
            experts_dim_unsplit,
            num_groups_dim,
            expert_capacity_dim,
            output_dim,
        ]))
    moe_output_dims = moe_input_dims[:-1] + [output_dim]
    output = mtf.einsum([expert_output, combine_tensor],
                        mtf.Shape(moe_output_dims))
    output = mtf.reshape(output, batch_and_length_dims + [output_dim])
    return output, loss * hparams.moe_loss_coef
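
A plain-Python sketch of the per-layer control flow in the expert stack above (pre-norm residual plus the optional per-layer mask); the callables are illustrative stand-ins, not the mtf API:

def expert_stack_sketch(x, layers, masks=None,
                        norm=lambda t: t, project_out=lambda t: t):
    # Illustrative control flow only: `norm` stands in for sublayer_rms_norm,
    # each `layer` for the "wi" dense + activation (+ dropout), and
    # `project_out` for the "wo" dense layer in the real code above.
    for layer_idx, layer in enumerate(layers):
        residual = x if layer_idx > 0 else 0.0  # no residual around layer 0
        if layer_idx > 0:
            x = norm(x)                         # pre-norm, as in the mtf code
        h = layer(x)
        if masks is not None:
            h = h * masks[layer_idx]            # heterogeneous per-layer mask
        x = project_out(h) + residual
    return x

# A single identity "expert layer" with no mask leaves the input unchanged.
assert expert_stack_sketch(1.0, [lambda t: t]) == 1.0
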
Example No. 8
def transformer_moe_layer_v1(
    inputs, output_dim, hparams, train, variable_dtype,
    layout=None, mesh_shape=None, nonpadding=None, activation=mtf.relu,
    num_microbatches=None):
  """Local mixture of experts that works well on TPU.

  Adapted from the paper https://arxiv.org/abs/1701.06538

  Note: until the algorithm and interface solidify, we pass in a hyperparameters
  dictionary in order not to complicate the interface in mtf_transformer.py.
  Once this code moves out of "research", we should pass the hyperparameters
  separately.

  Hyperparameters used:
    hparams.moe_num_experts: number of experts
    hparams.moe_hidden_size: size of hidden layer in each expert
    hparams.moe_group_size: size of each "group" for gating purposes
    hparams.moe_capacity_factor_train: a float
    hparams.moe_capacity_factor_eval: a float
    hparams.moe_gating: a string
    + all hyperparameters used by _top_2_gating()

  The number of parameters in the gating network is:
    (input_dim.size * hparams.moe_num_experts)

  The number of parameters in the experts themselves is:
    (hparams.moe_num_experts
     * (input_dim.size + output_dim.size)
     * hparams.moe_hidden_size)

  The input is n-dimensional: [<batch_and_length_dims>, input_dim], consisting
  of the representations of all positions in a batch of sequences.

  Each position of each sequence is sent to 0-2 experts.  The expert
  choices and the combination weights are determined by a learned gating
  function.

  This function returns a small auxiliary loss that should be added to the
  training loss of the model.  This loss helps to balance expert usage.
  Without the loss, it is very likely that a few experts will be trained and
  the rest will starve.

  Several hacks are necessary to get around current TPU limitations:

  - To ensure static shapes, we enforce (by truncation/padding)
    that each sequence sends the same number of elements to each expert.

    It would make more sense to enforce this equality over the entire batch,
    but due to our hacked-up gather-by-matmul implementation, we need to divide
    the batch into "groups".  For each group, the same number of elements
    are sent to each expert.

  TODO(noam): Factor this code better.  We want to be able to substitute
  different code for the experts themselves.

  Dimensions cheat sheet:
  B: batch dim(s)
  L: original sequence length
  M: input depth
  N: output depth
  G: number of groups
  S: group size
  E: number of experts
  C: expert capacity

  Args:
    inputs: a mtf.Tensor with shape [batch_dim(s), length_dim, input_dim]
    output_dim: a mtf.Dimension (for Transformer, this is input_dim)
    hparams: model hyperparameters
    train: a boolean
    variable_dtype: a mtf.VariableDType
    layout: optional - an input to mtf.convert_to_layout_rules
    mesh_shape: optional - an input to mtf.convert_to_shape
    nonpadding: an optional Tensor with shape [batch_dim(s), length_dim]
      and the same dtype as inputs, consisting of ones(nonpadding)
      and zeros(padding).
    activation: a function.
    num_microbatches: number of microbatches.

  Returns:
    outputs: a Tensor with shape [batch_dim(s), length_dim, output_dim]
    loss: a mtf scalar

  Raises:
    ValueError: on unrecognized hparams.moe_gating
  """
  # pylint: disable=line-too-long
  #
  # O outer_batch dimension can be used for expert replication, e.g.
  # outer_batch=4 for placing 128 experts on 512 cores with 4 replicas of each
  # expert.
  #
  # E.g. 16x16 basic example:
  #   moe_num_experts=512, num_groups=1024, batch=4096, length=256, d_model=1024
  # ---
  # Below, ` indicates the common way of splitting along a mesh dimension.
  #
  # orig_inputs      OB`LM Tensor
  #                  Shape[outer_batch=1, batch=4096, length=256, d_model=1024]
  #                  v (reshaped)
  # inputs           OG`SM
  #                  Shape[outer_batch=1, batch=1024, group=1024, d_model=1024]
  #
  # combine_tensor,
  # dispatch_tensor  OG`SEC
  #                  Shape[outer_batch=1, batch=1024, group=1024, expert_unsplit=512, expert_capacity=4]
  #
  # (dispatched inputs)
  # expert_inputs    OEG`CM
  #                  Shape[outer_batch=1, expert_unsplit=512, batch=1024, expert_capacity=4, d_model=1024]
  #                  v (re-split via ReshapeOperation)
  #                  OE`GCM
  #                  Shape[outer_batch=1, experts=512, batch_unsplit=1024, expert_capacity=4, d_model=1024]
  #
  # (hidden representation)
  # h                OE`GCH
  #                  Shape[outer_batch=1, experts=512, batch_unsplit=1024, expert_capacity=4, expert_hidden=8192]
  #
  # expert_output    OE`GCM
  #                  Shape[outer_batch=1, experts=512, batch_unsplit=1024, expert_capacity=4, d_model=1024]
  #                  v (re-split via ReshapeOperation)
  #                  OEG`CM
  #                  Shape[outer_batch=1, expert_unsplit=512, batch=1024, expert_capacity=4, d_model=1024]
  #
  # (combined expert_output)
  # output           OG`SM
  #                  Shape[outer_batch=1, batch=1024, group=1024, d_model=1024]
  #                  v (reshape)
  #                  OB`LM
  #                  Shape[outer_batch=1, batch=4096, length=256, d_model=1024]
  #
  # pylint: enable=line-too-long
  orig_inputs = inputs
  hidden_dim = mtf.Dimension("expert_hidden", hparams.moe_hidden_size)
  experts_dim = mtf.Dimension("experts", hparams.moe_num_experts)

  # We "cheat" here and look at the mesh shape and layout. This is to ensure
  # that the number of groups is a multiple of the mesh dimension
  # over which those groups are split.
  batch_and_length_dims, input_dim = (orig_inputs.shape.dims[:-1],
                                      orig_inputs.shape.dims[-1])
  # Hack: we assume that
  #   "outer_batch" == replication of experts
  #   mesh_dim_size can be derived from mesh_shape and orig_batch_dim
  #
  # We then require num_groups to be a multiple of mesh_dim_size.
  if orig_inputs.shape.dims[0].name == "outer_batch":
    outer_batch_dim, orig_batch_dim = orig_inputs.shape.dims[:2]
  else:
    outer_batch_dim, orig_batch_dim = (mtf.Dimension("outer_batch", 1),
                                       orig_inputs.shape.dims[0])

  # Number of MoE inputs (total number of positions across
  # batch_and_length_dims, per replica).
  n = 1
  for d in batch_and_length_dims:
    n *= d.size

  n = n // outer_batch_dim.size

  mesh_dim_size = mtf.tensor_dim_to_mesh_dim_size(layout, mesh_shape,
                                                  orig_batch_dim)
  num_groups, group_size = _split_into_groups(n, hparams.moe_group_size,
                                              mesh_dim_size)

  group_size_dim = mtf.Dimension("group", group_size)
  num_groups_dim = mtf.Dimension(orig_batch_dim.name, num_groups)

  moe_input_dims = [outer_batch_dim, num_groups_dim, group_size_dim, input_dim]
  # OGSM Tensor
  inputs = mtf.reshape(inputs, moe_input_dims)

  # Each sequence sends expert_capacity positions to each expert.
  if train:
    capacity_factor = hparams.moe_capacity_factor_train
  else:
    capacity_factor = hparams.moe_capacity_factor_eval
  expert_capacity = min(
      group_size_dim.size,
      int((group_size_dim.size * capacity_factor) / experts_dim.size))
  expert_capacity = max(expert_capacity, hparams.moe_min_expert_capacity)
  tf.logging.info("expert_capacity: %d" % expert_capacity)
  expert_capacity_dim = mtf.Dimension("expert_capacity", expert_capacity)
  experts_dim_unsplit = mtf.Dimension("expert_unsplit", experts_dim.size)
  batch_dim_unsplit = mtf.Dimension("batch_unsplit", num_groups_dim.size)
  if nonpadding is not None:
    nonpadding = mtf.zeros(
        inputs.mesh, batch_and_length_dims, dtype=inputs.dtype) + nonpadding
    nonpadding = mtf.reshape(nonpadding, moe_input_dims[:-1])
  if hparams.moe_gating == "top_2":
    # combine_tensor,
    # dispatch_tensor  OG`SEC Tensors
    # (G is generally split along mesh dim)
    dispatch_tensor, combine_tensor, loss = _top_2_gating(
        inputs=inputs,
        outer_expert_dims=None,
        experts_dim=experts_dim_unsplit,
        expert_capacity_dim=expert_capacity_dim,
        hparams=hparams,
        train=train,
        variable_dtype=variable_dtype,
        importance=nonpadding,
        num_microbatches=num_microbatches)
  elif hparams.moe_gating == "rand_1":
    dispatch_tensor, combine_tensor, loss = _rand_1_gating(
        inputs=inputs,
        outer_expert_dims=None,
        experts_dim=experts_dim_unsplit,
        expert_capacity_dim=expert_capacity_dim,
        hparams=hparams,
        train=train,
        variable_dtype=variable_dtype,
        importance=nonpadding,
        num_microbatches=num_microbatches)
  elif hparams.moe_gating == "switch":
    dispatch_tensor, combine_tensor, loss = _switch_gating(
        inputs=inputs,
        outer_expert_dims=None,
        experts_dim=experts_dim_unsplit,
        expert_capacity_dim=expert_capacity_dim,
        hparams=hparams,
        train=train,
        variable_dtype=variable_dtype,
        importance=nonpadding,
        num_microbatches=num_microbatches)
  else:
    raise ValueError("unknown hparams.moe_gating=%s" % hparams.moe_gating)

  expert_inputs = mtf.einsum([inputs, dispatch_tensor],
                             mtf.Shape([
                                 outer_batch_dim, experts_dim_unsplit,
                                 num_groups_dim, expert_capacity_dim, input_dim
                             ]))

  expert_inputs = mtf.reshape(
      expert_inputs,
      mtf.Shape([
          outer_batch_dim, experts_dim, batch_dim_unsplit, expert_capacity_dim,
          input_dim
      ]))

  # Now feed the expert inputs through the experts.
  h = mtf.layers.dense_product(
      expert_inputs,
      reduced_dims=expert_inputs.shape.dims[-1:],
      new_dims=[hidden_dim],
      expert_dims=[experts_dim],
      activation_functions=activation, use_bias=False,
      variable_dtype=variable_dtype, name="wi")

  if train and hparams.moe_dropout_rate != 0.0:
    h = mtf.dropout(h, 1.0 - hparams.moe_dropout_rate)

  def _compute_output(hidden, layer_name):
    """Compute the output of the attention layer from the hidden vector."""
    expert_output = mtf.layers.dense(
        hidden, output_dim, expert_dims=[experts_dim], use_bias=False,
        reduced_dims=hidden.shape.dims[-1:], variable_dtype=variable_dtype,
        name=layer_name)

    expert_output = mtf.reshape(
        expert_output,
        mtf.Shape([
            outer_batch_dim,
            experts_dim_unsplit,
            num_groups_dim,
            expert_capacity_dim,
            output_dim,
        ]))
    moe_output_dims = moe_input_dims[:-1] + [output_dim]
    output = mtf.einsum([expert_output, combine_tensor],
                        mtf.Shape(moe_output_dims))
    output = mtf.reshape(output, batch_and_length_dims + [output_dim])
    return output

  if hparams.moe_use_experts_attention:
    # We share k_h and v_h with no degradation in performance
    q_h, k_h = h, h
    outputs = []
    q = _compute_output(q_h, layer_name="q_wo")
    k = _compute_output(k_h, layer_name="k_wo")
    outputs.append(q)
    outputs.append(k)
    return outputs, loss * hparams.moe_loss_coef
  else:
    output = _compute_output(h, layer_name="wo")
    return output, loss * hparams.moe_loss_coef
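
A worked check of the 16x16 example in the long comment above; the capacity factor is not stated there, so a value of 2.0 is assumed, which reproduces expert_capacity=4:

batch, length = 4096, 256        # from the example comment above
num_groups, num_experts = 1024, 512
capacity_factor = 2.0            # assumed; not given in the comment

group_size = (batch * length) // num_groups
assert group_size == 1024
expert_capacity = min(group_size,
                      int((group_size * capacity_factor) / num_experts))
assert expert_capacity == 4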