Exemplo n.º 1
0
        def first_block_attention():
            """Attend from block 0 of the queries to block 0 of the memory.

            q, k and v are each cut down to a size-1 slice of the num_blocks
            axis, so the result keeps that axis and can later be concatenated
            with the outputs of the remaining blocks.
            """
            # NOTE(review): no causal mask is applied here; if the enclosing
            # attention is meant to be causal, block-0 positions can see later
            # positions of block 0 — confirm against the caller.
            q0, k0, v0 = [mtf.slice(t, 0, 1, num_blocks.name) for t in (q, k, v)]
            single_block = q0.shape.dims[2]

            scores = mtf.einsum(
                [q0, k0],
                mtf.TensorShape([batch, heads, single_block, blength, mlength]))
            probs = mtf.softmax(scores, mlength)
            return mtf.einsum(
                [probs, v0],
                mtf.TensorShape(
                    [batch, heads, single_block, blength, kv_channels]))
Exemplo n.º 2
0
def dot_product_attention(q,
                          k,
                          v,
                          mask,
                          dropout=0.0,
                          dropout_broadcast_dims=None):
    """Dot-product attention.

    Args:
      q: Tensor with shape [..., length_q, depth_k]. Typically leading
        dimensions are [batch, heads].
      k: Tensor with shape [..., length_kv, depth_k]. Leading dimensions must
        match with q.
      v: Tensor with shape [..., length_kv, depth_v]. Leading dimensions must
        match with q.
      mask: mask Tensor (see attention_mask()) added to the attention logits,
        or None for no mask.
      dropout: a float, the probability of dropping an attention weight
        (0.0 disables dropout).
      dropout_broadcast_dims: an optional list of mtf.Dimension; these
        dimensions are removed from the dropout noise shape, so the same
        dropout decision is shared along them.

    Returns:
      Tensor with shape [..., length_q, depth_v].
    """
    length_kv = k.shape.dims[-2]
    # depth_k is absent from the output shape, so einsum contracts over it.
    logits_shape = mtf.TensorShape(q.shape.dims[:-1] + [length_kv])
    logits = mtf.einsum([q, k], logits_shape)
    if mask is not None:
        logits += mask
    weights = mtf.softmax(logits, length_kv)
    if dropout != 0.0:
        # The argument defaults to None; subtract an empty list instead so the
        # noise shape is simply the full weights shape in that case.
        broadcast_dims = dropout_broadcast_dims or []
        # mtf.dropout takes a keep probability, hence 1.0 - dropout.
        weights = mtf.dropout(weights,
                              1.0 - dropout,
                              noise_shape=weights.shape - broadcast_dims)
    depth_v = v.shape.dims[-1]
    # length_kv is absent from the output shape, so einsum contracts over it.
    outputs_shape = mtf.TensorShape(q.shape.dims[:-1] + [depth_v])
    outputs = mtf.einsum([weights, v], outputs_shape)
    return outputs
Exemplo n.º 3
0
def moe_v0(inputs,
           hidden_dim,
           output_dim,
           experts_dim,
           loss_coef=1e-3,
           overhead=1.0):
    """Local mixture of experts that works well on TPU.

    See https://arxiv.org/abs/1701.06538

    There are num_experts expert networks, each containing a relu-activated
    hidden layer of size hidden_size, followed by an output projection.

    The number of parameters is thus:
      num_experts * (input_size * hidden_size + hidden_size * output_size)

    The input is 3d: [batch, length, depth], consisting of the representations
    of all positions in a batch of sequences.

    Each position of each sequence is sent to 0-2 experts.  The expert
    choices and the combination weights are determined by a learned gating
    function.

    This function returns a small auxiliary loss that should be added to the
    training loss of the model.  This loss helps to balance expert usage.
    Without the loss, it is very likely that a few experts will be trained and
    the rest will starve.

    Several hacks are necessary to get around current TPU limitations:

    - To ensure static shapes, we enforce (by truncation/padding)
      that each sequence sends the same number of elements to each expert.

      It would make more sense to enforce this equality over the entire batch,
      as opposed to on individual sequences.  This would allow more freedom
      for individual sequences to be unbalanced.  Unfortunately, that would
      slow down our hacked-up gather-by-matmul implementation.

      TODO(noam): There is no real reason for a single sequence to be the unit
        of equal allocation.  Reshaping the inputs would allow us to pick a
        different unit of equal allocation.

    TODO(noam): Factor this code better.  We want to be able to substitute
    different code for the experts themselves.  We also want to integrate this
    gating/dispatching logic into multi-device mixtures-of-experts.

    Args:
      inputs: a mtf.Tensor with shape [batch_dim, length_dim, input_dim]
      hidden_dim: a mtf.Dimension
      output_dim: a mtf.Dimension
      experts_dim: a mtf.Dimension
      loss_coef: a float scalar
      overhead: multiplicative factor of how much spare capacity to assign

    Returns:
      outputs: a Tensor with shape [batch_dim, length_dim, output_dim]
      loss: a mtf scalar
    """
    batch_dim, length_dim, input_dim = inputs.shape.dims

    # Each sequence sends expert_capacity positions to each expert.  The
    # factor of 2 matches the (up to) two experts chosen per position; the
    # capacity never exceeds the sequence length.
    expert_capacity = min(
        length_dim.size,
        int((length_dim.size * 2 * overhead) / experts_dim.size))
    expert_capacity_dim = mtf.Dimension("expert_capacity", expert_capacity)

    # Differently-named copies of the experts/batch dimensions; reshaping
    # between the named and "unsplit" versions moves data between layouts.
    experts_dim_unsplit = mtf.Dimension("expert_unsplit", experts_dim.size)
    batch_dim_unsplit = mtf.Dimension("batch_unsplit", batch_dim.size)

    # This is the learned gating function.
    # shape = [batch_dim, length_dim, experts_dim_unsplit]
    gates = mtf.softmax(dense(inputs, experts_dim_unsplit),
                        experts_dim_unsplit)

    assignment_shape = mtf.TensorShape(
        [batch_dim, length_dim, experts_dim_unsplit, expert_capacity_dim])

    backward_assignment = mtf.slicewise(functools.partial(
        _truncated_top_2_gating, expert_capacity=expert_capacity), [gates],
                                        output_shape=assignment_shape,
                                        splittable_dims=[batch_dim],
                                        name="backward_assignment")

    # 0/1 version of the combine weights, used to gather the expert inputs.
    forward_assignment = mtf.cast(mtf.cast(backward_assignment, tf.bool),
                                  inputs.dtype)

    # put num_experts dimension first to make split easier in alltoall
    expert_inputs = mtf.einsum([inputs, forward_assignment],
                               mtf.TensorShape([
                                   experts_dim_unsplit, batch_dim,
                                   expert_capacity_dim, input_dim
                               ]))

    expert_inputs = mtf.reshape(
        expert_inputs,
        mtf.TensorShape(
            [experts_dim, batch_dim_unsplit, expert_capacity_dim, input_dim]))

    # Now feed the expert inputs through the experts.
    h = dense(expert_inputs,
              hidden_dim,
              expert_dims=[experts_dim],
              activation=mtf.relu,
              name="x0")
    # shape = [experts_dim, batch_dim_unsplit, expert_capacity_dim, output_dim]
    expert_output = dense(h, output_dim, expert_dims=[experts_dim], name="x1")

    # The last dimension of the expert output is output_dim (not input_dim);
    # keeping it lets the combine einsum below produce output_dim.
    expert_output = mtf.reshape(
        expert_output,
        mtf.TensorShape(
            [experts_dim_unsplit, batch_dim, expert_capacity_dim, output_dim]))

    output = mtf.einsum([expert_output, backward_assignment],
                        mtf.TensorShape([batch_dim, length_dim, output_dim]))

    # Total gate weight routed to each expert, per sequence.
    importance = mtf.reduce_sum(backward_assignment,
                                output_shape=mtf.TensorShape(
                                    [batch_dim, experts_dim_unsplit]))

    loss = cv_squared(importance) * loss_coef
    return output, loss
Exemplo n.º 4
0
def masked_local_attention_1d(query_antecedent,
                              memory_antecedent,
                              kv_channels,
                              heads,
                              block_length=128,
                              name=None):
    """Attention to the source position and a neighborhood to the left of it.

    The sequence is divided into blocks of length block_length.
    Attention for a given query position can only see memory positions
    less than or equal to the query position, in the corresponding block
    and the previous block.

    Args:
      query_antecedent: a mtf.Tensor with shape
        [batch, query_length, io_channels]
      memory_antecedent: a mtf.Tensor with shape
        [batch, memory_length, io_channels] (optional). Currently,
        memory_length must have the same size as query_length, but a
        different name.
      kv_channels: a mtf.Dimension (the size of the key and value vectors)
      heads: a mtf.Dimension (the number of heads)
      block_length: an integer, representing receptive fields for attention.
        If the sequence is shorter than two blocks, the whole sequence is
        treated as a single block.
      name: an optional string.

    Returns:
      a Tensor of shape [batch, query_length, io_channels]

    Raises:
      ValueError: if the memory batch or channel dimensions don't match the
        query's, if memory_length and query_length differ in size, or if the
        (possibly adjusted) block length does not evenly divide query_length.
    """
    with tf.variable_scope(name,
                           default_name="multihead_attention",
                           values=[query_antecedent, memory_antecedent]):

        batch, query_length, io_channels = query_antecedent.shape.dims
        q_var, k_var, v_var, o_var = multihead_attention_vars(
            query_antecedent.mesh, heads, io_channels, kv_channels,
            query_antecedent.dtype)

        if memory_antecedent is None:
            memory_antecedent = rename_length_to_memory_length(
                query_antecedent, query_length.name)
        memory_batch, memory_length, memory_channels = (
            memory_antecedent.shape.dims)
        if memory_batch != batch:
            raise ValueError("memory batch must equal query batch")
        if memory_channels != io_channels:
            raise ValueError("memory channels must equal query channels")
        if memory_length.size != query_length.size:
            # Queries and memory are cut into the same blocks below.
            raise ValueError("memory length must equal query length")

        # Get query q, keys k and values v.
        q = mtf.einsum([query_antecedent, q_var],
                       mtf.TensorShape(
                           [batch, heads, query_length, kv_channels]))
        k = mtf.einsum([memory_antecedent, k_var],
                       mtf.TensorShape(
                           [batch, heads, memory_length, kv_channels]))
        v = mtf.einsum([memory_antecedent, v_var],
                       mtf.TensorShape(
                           [batch, heads, memory_length, kv_channels]))

        # Sequences shorter than two blocks are handled as a single block.
        # Padding is not supported, so the block length must divide the
        # sequence length.
        if query_length.size < block_length * 2:
            block_length = query_length.size
        if query_length.size % block_length != 0:
            raise ValueError(
                "block_length (%d) must evenly divide query_length (%d)" %
                (block_length, query_length.size))
        blength = mtf.Dimension("block_length", block_length)
        mlength = mtf.Dimension("mem_block_length", block_length)
        num_blocks = mtf.Dimension("num_blocks",
                                   query_length.size // block_length)

        q = mtf.reshape(
            q,
            mtf.TensorShape([batch, heads, num_blocks, blength, kv_channels]))
        k = mtf.reshape(
            k,
            mtf.TensorShape([batch, heads, num_blocks, mlength, kv_channels]))
        v = mtf.reshape(
            v,
            mtf.TensorShape([batch, heads, num_blocks, mlength, kv_channels]))

        # compute attention for the first query block.
        def first_block_attention():
            """Compute causal attention for the first block over itself."""
            first_q = mtf.slice(q, 0, 1, num_blocks.name)
            first_k = mtf.slice(k, 0, 1, num_blocks.name)
            first_v = mtf.slice(v, 0, 1, num_blocks.name)
            block = first_q.shape.dims[2]

            first_logits = mtf.einsum(
                [first_q, first_k],
                mtf.TensorShape([batch, heads, block, blength, mlength]))
            # Query offset i and memory offset j in block 0 are the same
            # absolute positions, so positions j > i are in the future and
            # must be masked out (the tail blocks below get the same
            # treatment via attention_bias_local_block).
            future = mtf.less(
                mtf.range(first_q.mesh, blength, dtype=tf.int32),
                mtf.range(first_q.mesh, mlength, dtype=tf.int32))
            first_logits += mtf.cast(future, first_logits.dtype) * -1e9
            weights = mtf.softmax(first_logits, mlength)
            first_output = mtf.einsum(
                [weights, first_v],
                mtf.TensorShape([batch, heads, block, blength, kv_channels]))
            return first_output

        # Attention for first block, since query_length = key_length.
        first_output = first_block_attention()

        if num_blocks.size == 1:
            # The whole sequence fits in one block: nothing else to attend.
            final_output = first_output
        else:
            # Concatenate two adjacent blocks to compute the overlapping
            # memory block.
            def local(x):
                """For blocks 1..n-1, concat the previous and current block."""
                prev_block = mtf.slice(x, 0, num_blocks.size - 1,
                                       num_blocks.name)
                cur_block = mtf.slice(x, 1, num_blocks.size - 1,
                                      num_blocks.name)
                return mtf.concat([prev_block, cur_block], mlength.name)

            local_k = local(k)
            local_v = local(v)
            mblocks = local_k.shape.dims[2]
            # Memory axis of the concatenated blocks (two blocks long).
            local_length = local_k.shape.dims[3]
            # Calculate the causal mask to avoid peeking into the future. We
            # compute this once and reuse it for all blocks since the
            # block_size is known.
            mask = attention_bias_local_block(query_antecedent.mesh, blength,
                                              local_length)

            # Remove the first block from q since we already computed that.
            tail_q = mtf.slice(q, 1, num_blocks.size - 1, num_blocks.name)

            # Compatibility between q and k for rest of the blocks.
            # Shape [batch, heads, num_blocks - 1, block_length, local_length]
            attention = mtf.einsum([tail_q, local_k],
                                   mtf.TensorShape([
                                       batch, heads, mblocks, blength,
                                       local_length
                                   ]))
            attention += mask
            attention = mtf.softmax(attention, local_length)

            # Run attention for rest of the blocks.
            # Shape [batch, heads, num_blocks-1, block_length, kv_channels]
            output = mtf.einsum([attention, local_v],
                                mtf.TensorShape([
                                    batch, heads, mblocks, blength,
                                    kv_channels
                                ]))
            # Now concatenate the first and rest of the blocks.
            final_output = mtf.concat([first_output, output],
                                      num_blocks.name)

        final_output = mtf.reshape(
            final_output,
            mtf.TensorShape([batch, heads, query_length, kv_channels]))
        return mtf.einsum([final_output, o_var],
                          mtf.TensorShape([batch, query_length, io_channels]))
Exemplo n.º 5
0
def _top_2_gating(inputs,
                  outer_expert_dims,
                  experts_dim,
                  expert_capacity_dim,
                  hparams,
                  train,
                  importance=None):
    """Compute gating for mixture-of-experts in TensorFlow.

    Note: until the algorithm and interface solidify, we pass in a
    hyperparameters object in order not to complicate the interface in
    mtf_transformer.py.  Once this code moves out of "research", we should
    pass the hyperparameters separately.

    Hyperparameters used:
      hparams.moe_use_second_place_loss: a boolean
      hparams.moe_second_policy_train: a string, one of "all", "none",
        "threshold", "random"
      hparams.moe_second_policy_eval: a string (same choices)
      hparams.moe_second_threshold_train: a float
      hparams.moe_second_threshold_eval: a float

    The returned dispatch_tensor is used to map (via einsum) from the inputs
    to the expert_inputs.  Likewise, the returned combine_tensor is used to
    map (via einsum) from the expert outputs to the outputs.  Both tensors are
    mostly zeros; dispatch_tensor is combine_tensor with every nonzero entry
    replaced by 1.  The shapes of the tensors are as follows.

    inputs: [<batch_dims>, group_size_dim, input_dim]
    importance: [<batch_dims>, group_size_dim]
    dispatch_tensor:
      [<batch_dims>, group_size_dim, experts_dim, expert_capacity_dim]
    expert_inputs:
      [<batch_dims>, experts_dim, expert_capacity_dim, input_dim]

    expert_outputs: [<batch_dims>, experts_dim, expert_capacity_dim, output_dim]
    combine_tensor:
      [<batch_dims>, group_size_dim, experts_dim, expert_capacity_dim]
    outputs: [<batch_dims>, group_size_dim, output_dim]

    "importance" is an optional tensor with one floating-point value for each
    input vector.  If the importance of an input is 1.0, then we send it to
    up to 2 experts.  If 0.0 < importance < 1.0, then we send it to at most
    one expert.  If importance == 0.0, then we send it to no experts.

    We use "importance" at the second-level gating function of a hierarchical
    mixture of experts.  Inputs to the first-choice expert-group get importance
    1.0.  Inputs to the second-choice expert group get importance 0.5.
    Inputs that represent padding get importance 0.0.

    Args:
      inputs: a mtf.Tensor with shape [<batch_dims>, group_size_dim, input_dim]
      outer_expert_dims: an optional list of dimensions.  This is for the case
        where we are at an inner level of a hierarchical MoE.
      experts_dim: a Dimension (the number of experts)
      expert_capacity_dim: a Dimension (number of examples per group per
        expert)
      hparams: model hyperparameters.
      train: a boolean; selects the *_train or *_eval second-place policy and
        threshold.
      importance: an optional tensor with shape [<batch_dims>, group_size_dim]

    Returns:
      dispatch_tensor: a Tensor with shape
        [<batch_dims>, group_size_dim, experts_dim, expert_capacity_dim]
      combine_tensor: a Tensor with shape
        [<batch_dims>, group_size_dim, experts_dim, expert_capacity_dim]
      loss: a mtf scalar (load-balancing loss, cast to inputs.dtype)

    Raises:
      ValueError: on illegal hyperparameters (an unknown second-place policy)
    """
    group_size_dim, unused_input_dim = inputs.shape.dims[-2:]

    raw_gates = mtf.softmax(
        mtf_layers.dense(inputs,
                         experts_dim,
                         use_bias=False,
                         expert_dims=outer_expert_dims), experts_dim)

    # The internals of this function run in float32.
    #   bfloat16 seems to reduce quality.
    raw_gates = mtf.to_float(raw_gates)

    expert_capacity_f = float(expert_capacity_dim.size)

    # FIND TOP 2 EXPERTS PER POSITION
    # Find the top expert for each position. shape=[batch, group]
    index_1, gate_1 = mtf.top_1(raw_gates, experts_dim)
    # [batch, group, experts]
    mask_1 = mtf.one_hot(index_1, experts_dim, dtype=raw_gates.dtype)
    density_1_proxy = raw_gates
    if importance is not None:
        # Only inputs with importance exactly 1.0 get a first-place expert.
        mask_1 *= mtf.to_float(mtf.equal(importance, 1.0))
        gate_1 *= mtf.to_float(mtf.equal(importance, 1.0))
        density_1_proxy *= mtf.to_float(mtf.equal(importance, 1.0))
    # Note: for inputs whose mask_1 was zeroed above, nothing is removed here,
    # so their "second-place" expert below is actually their top expert.
    gates_without_top_1 = raw_gates * (1.0 - mask_1)
    # [batch, group]
    index_2, gate_2 = mtf.top_1(gates_without_top_1, experts_dim)
    # [batch, group, experts]
    mask_2 = mtf.one_hot(index_2, experts_dim, dtype=raw_gates.dtype)
    if importance is not None:
        # Inputs with importance 0.0 (e.g. padding) get no expert at all.
        mask_2 *= mtf.to_float(mtf.greater(importance, 0.0))

    # Renormalize the two gates to sum to ~1.  This happens before the
    # second-place policy, so the "threshold"/"random" policies below see the
    # normalized gate_2.
    denom = gate_1 + gate_2 + 1e-9
    gate_1 /= denom
    gate_2 /= denom

    # BALANCING LOSSES
    # shape = [batch, experts]
    # We want to equalize the fraction of the batch assigned to each expert
    density_1 = mtf.reduce_mean(mask_1, reduced_dim=group_size_dim)
    # Something continuous that is correlated with what we want to equalize.
    density_1_proxy = mtf.reduce_mean(density_1_proxy,
                                      reduced_dim=group_size_dim)
    density_1 = mtf.Print(
        density_1, [mtf.reduce_mean(density_1, output_shape=[experts_dim])],
        "density_1",
        summarize=1000)
    loss = (mtf.reduce_mean(density_1_proxy * density_1) *
            float(experts_dim.size * experts_dim.size))

    if hparams.moe_use_second_place_loss:
        # Also add a loss to encourage all experts to be used equally also as the
        # second-place expert.  Experimentally, this seems to be a wash.
        # We want to equalize the fraction of the batch assigned to each expert:
        density_2 = mtf.reduce_mean(mask_2, reduced_dim=group_size_dim)
        # As a proxy for density_2, we renormalize the raw gates after the top one
        # has been removed.
        normalized = gates_without_top_1 / (mtf.reduce_sum(
            gates_without_top_1, reduced_dim=experts_dim) + 1e-9)
        density_2_proxy = mtf.reduce_mean(normalized,
                                          reduced_dim=group_size_dim)
        loss_2 = (mtf.reduce_mean(density_2_proxy * density_2) *
                  float(experts_dim.size * experts_dim.size))
        loss += loss_2 * 0.5

    # Depending on the policy in the hparams, we may drop out some of the
    # second-place experts.
    policy = (hparams.moe_second_policy_train
              if train else hparams.moe_second_policy_eval)
    threshold = (hparams.moe_second_threshold_train
                 if train else hparams.moe_second_threshold_eval)
    if policy == "all":
        # Use second-place experts for all examples.
        pass
    elif policy == "none":
        # Never use second-place experts.
        mask_2 = mtf.zeros_like(mask_2)
    elif policy == "threshold":
        # Use second-place experts if gate_2 > threshold.
        mask_2 *= mtf.to_float(mtf.greater(gate_2, threshold))
    elif policy == "random":
        # Use second-place experts with probability min(1.0, gate_2 / threshold).
        mask_2 *= mtf.to_float(
            mtf.less(mtf.random_uniform(gate_2.mesh, gate_2.shape),
                     gate_2 / max(threshold, 1e-9)))
    else:
        raise ValueError("Unknown policy %s" % policy)
    mask_2 = mtf.Print(mask_2,
                       [mtf.reduce_mean(mask_2, output_shape=[experts_dim])],
                       "density_2",
                       summarize=1000)

    # COMPUTE ASSIGNMENT TO EXPERTS
    # [batch, group, experts]
    # This is the position within the expert's mini-batch for this sequence
    position_in_expert_1 = mtf.cumsum(mask_1, group_size_dim,
                                      exclusive=True) * mask_1
    # Remove the elements that don't fit. [batch, group, experts]
    mask_1 *= mtf.to_float(mtf.less(position_in_expert_1, expert_capacity_f))
    # [batch, experts]
    # How many examples in this sequence go to this expert
    mask_1_count = mtf.reduce_sum(mask_1, reduced_dim=group_size_dim)
    # [batch, group] - mostly ones, but zeros where something didn't fit
    mask_1_flat = mtf.reduce_sum(mask_1, reduced_dim=experts_dim)
    # [batch, group]
    position_in_expert_1 = mtf.reduce_sum(position_in_expert_1,
                                          reduced_dim=experts_dim)
    # Weight assigned to first expert.  [batch, group]
    gate_1 *= mask_1_flat

    # [batch, group, experts]
    # Second-place slots start after the first-place examples already
    # assigned to the same expert (mask_1_count).
    position_in_expert_2 = (
        mtf.cumsum(mask_2, group_size_dim, exclusive=True) + mask_1_count)
    position_in_expert_2 *= mask_2
    mask_2 *= mtf.to_float(mtf.less(position_in_expert_2, expert_capacity_f))
    # mask_2_count = mtf.reduce_sum(mask_2, reduced_dim=experts_dim)
    mask_2_flat = mtf.reduce_sum(mask_2, reduced_dim=experts_dim)
    gate_2 *= mask_2_flat
    position_in_expert_2 = mtf.reduce_sum(position_in_expert_2,
                                          reduced_dim=experts_dim)

    # [batch, group, experts, expert_capacity]
    combine_tensor = (
        gate_1 * mask_1_flat * mtf.one_hot(index_1, experts_dim) *
        mtf.one_hot(mtf.to_int32(position_in_expert_1), expert_capacity_dim) +
        gate_2 * mask_2_flat * mtf.one_hot(index_2, experts_dim) *
        mtf.one_hot(mtf.to_int32(position_in_expert_2), expert_capacity_dim))

    combine_tensor = mtf.cast(combine_tensor, inputs.dtype)
    loss = mtf.cast(loss, inputs.dtype)

    # 1 wherever combine_tensor is nonzero, 0 elsewhere.
    dispatch_tensor = mtf.cast(mtf.cast(combine_tensor, tf.bool),
                               combine_tensor.dtype)

    return dispatch_tensor, combine_tensor, loss
Exemplo n.º 6
0
def _top_2_gating(inputs, experts_dim, expert_capacity_dim, max_experts,
                  hparams, train):
    """Compute gating for mixture-of-experts in TensorFlow.

    Note: until the algorithm and interface solidify, we pass in a
    hyperparameters object in order not to complicate the interface in
    mtf_transformer.py.  Once this code moves out of "research", we should
    pass the hyperparameters separately.

    Hyperparameters used:
      hparams.moe_use_second_place_loss: a boolean
      hparams.moe_second_policy_train: a string, one of "all", "none",
        "threshold", "random"
      hparams.moe_second_policy_eval: a string (same choices)
      hparams.moe_second_threshold_train: a float
      hparams.moe_second_threshold_eval: a float

    max_experts is a float tensor with shape [batch_dim, group_dim]
    indicating at most how many experts to use per example.  This can be
    used to prevent padding from going to experts.

    The returned forward_assignment is a tensor used to map (via einsum) from
    the inputs to the expert_inputs.  Likewise, the returned
    backward_assignment is used to map (via einsum) from the expert outputs to
    the outputs.  Both assignments are mostly zeros; forward_assignment is
    backward_assignment with every nonzero entry replaced by 1.  The shapes of
    all of these are as follows.

    inputs: [batch_dim, group_dim, input_dim]
    forward_assignment: [batch_dim, group_dim, experts_dim, expert_capacity_dim]
    expert_inputs: [batch_dim, experts_dim, expert_capacity_dim, input_dim]

    expert_outputs: [batch_dim, experts_dim, expert_capacity_dim, output_dim]
    backward_assignment: [batch_dim, group_dim, experts_dim,
                          expert_capacity_dim]
    outputs: [batch_dim, group_dim, output_dim]

    Args:
      inputs: a mtf.Tensor with shape [batch_dim, group_dim, input_dim]
      experts_dim: a Dimension (the number of experts)
      expert_capacity_dim: a Dimension (number of examples per group per
        expert)
      max_experts: optional mtf.Tensor with shape [batch_dim, group_dim]
      hparams: model hyperparameters.
      train: a boolean; selects the *_train or *_eval second-place policy and
        threshold.

    Returns:
      forward_assignment: a Tensor with shape
        [batch_dim, group_dim, experts_dim, expert_capacity_dim]
      backward_assignment: a Tensor with shape
        [batch_dim, group_dim, experts_dim, expert_capacity_dim]
      loss: a mtf scalar (load-balancing loss)

    Raises:
      ValueError: on illegal hyperparameters (an unknown second-place policy)
    """
    unused_batch_dim, group_dim, unused_input_dim = inputs.shape.dims

    raw_gates = mtf.softmax(
        mtf_layers.dense(inputs, experts_dim, use_bias=False), experts_dim)

    expert_capacity_f = float(expert_capacity_dim.size)

    # FIND TOP 2 EXPERTS PER POSITION
    # Find the top expert for each position. shape=[batch, group]
    index_1, gate_1 = mtf.top_1(raw_gates, experts_dim)
    # [batch, group, experts]
    mask_1 = mtf.one_hot(index_1, experts_dim, dtype=raw_gates.dtype)
    gates_without_top_1 = raw_gates * (1.0 - mask_1)
    # [batch, group]
    index_2, gate_2 = mtf.top_1(gates_without_top_1, experts_dim)
    # [batch, group, experts]
    mask_2 = mtf.one_hot(index_2, experts_dim, dtype=raw_gates.dtype)

    if max_experts is not None:
        # Drop the first-place expert where max_experts < 1 and the
        # second-place expert where max_experts < 2; the loss proxies are
        # masked the same way so those positions don't count toward balancing.
        geq1 = mtf.to_float(mtf.greater_equal(max_experts, 1.0))
        geq2 = mtf.to_float(mtf.greater_equal(max_experts, 2.0))
        mask_1 *= geq1
        mask_2 *= geq2
        raw_gates *= geq1
        gates_without_top_1 *= geq2

    # BALANCING LOSSES
    # shape = [batch, experts]
    # We want to equalize the fraction of the batch assigned to each expert
    density_1 = mtf.reduce_mean(mask_1, reduced_dim=group_dim)
    # Something continuous that is correlated with what we want to equalize.
    density_1_proxy = mtf.reduce_mean(raw_gates, reduced_dim=group_dim)
    density_1 = mtf.Print(
        density_1, [mtf.reduce_mean(density_1, output_shape=[experts_dim])],
        "density_1",
        summarize=1000)
    loss = (mtf.reduce_mean(density_1_proxy * density_1) *
            float(experts_dim.size * experts_dim.size))

    if hparams.moe_use_second_place_loss:
        # Also add a loss to encourage all experts to be used equally also as the
        # second-place expert.  Experimentally, this seems to be a wash.
        # We want to equalize the fraction of the batch assigned to each expert:
        density_2 = mtf.reduce_mean(mask_2, reduced_dim=group_dim)
        # As a proxy for density_2, we renormalize the raw gates after the top one
        # has been removed.
        normalized = gates_without_top_1 / (mtf.reduce_sum(
            gates_without_top_1, reduced_dim=experts_dim) + 1e-9)
        density_2_proxy = mtf.reduce_mean(normalized, reduced_dim=group_dim)
        loss_2 = (mtf.reduce_mean(density_2_proxy * density_2) *
                  float(experts_dim.size * experts_dim.size))
        loss += loss_2 * 0.5

    # Depending on the policy in the hparams, we may drop out some of the
    # second-place experts.  gate_2 is not yet renormalized at this point, so
    # "threshold"/"random" compare against its raw softmax probability.
    policy = (hparams.moe_second_policy_train
              if train else hparams.moe_second_policy_eval)
    threshold = (hparams.moe_second_threshold_train
                 if train else hparams.moe_second_threshold_eval)
    if policy == "all":
        # Use second-place experts for all examples.
        pass
    elif policy == "none":
        # Never use second-place experts.
        mask_2 = mtf.zeros_like(mask_2)
    elif policy == "threshold":
        # Use second-place experts if gate_2 > threshold.
        mask_2 *= mtf.to_float(mtf.greater(gate_2, threshold))
    elif policy == "random":
        # Use second-place experts with probability min(1.0, gate_2 / threshold).
        mask_2 *= mtf.to_float(
            mtf.less(mtf.random_uniform(gate_2.mesh, gate_2.shape),
                     gate_2 / max(threshold, 1e-9)))
    else:
        raise ValueError("Unknown policy %s" % policy)
    mask_2 = mtf.Print(mask_2,
                       [mtf.reduce_mean(mask_2, output_shape=[experts_dim])],
                       "density_2",
                       summarize=1000)

    # COMPUTE ASSIGNMENT TO EXPERTS
    # [batch, group, experts]
    # This is the position within the expert's mini-batch for this sequence
    position_in_expert_1 = mtf.cumsum(mask_1, group_dim,
                                      exclusive=True) * mask_1
    # Remove the elements that don't fit. [batch, group, experts]
    mask_1 *= mtf.to_float(mtf.less(position_in_expert_1, expert_capacity_f))
    # [batch, experts]
    # How many examples in this sequence go to this expert
    mask_1_count = mtf.reduce_sum(mask_1, reduced_dim=group_dim)
    # [batch, group] - mostly ones, but zeros where something didn't fit
    mask_1_flat = mtf.reduce_sum(mask_1, reduced_dim=experts_dim)
    # [batch, group]
    position_in_expert_1 = mtf.reduce_sum(position_in_expert_1,
                                          reduced_dim=experts_dim)
    # Weight assigned to first expert.  [batch, group]
    gate_1 *= mask_1_flat

    # [batch, group, experts]
    # Second-place slots start after the first-place examples already
    # assigned to the same expert (mask_1_count).
    position_in_expert_2 = (mtf.cumsum(mask_2, group_dim, exclusive=True) +
                            mask_1_count)
    position_in_expert_2 *= mask_2
    mask_2 *= mtf.to_float(mtf.less(position_in_expert_2, expert_capacity_f))
    # mask_2_count = mtf.reduce_sum(mask_2, reduced_dim=experts_dim)
    mask_2_flat = mtf.reduce_sum(mask_2, reduced_dim=experts_dim)
    gate_2 *= mask_2_flat
    position_in_expert_2 = mtf.reduce_sum(position_in_expert_2,
                                          reduced_dim=experts_dim)

    # renormalize the two gate values to add up to 1 (after capacity trimming,
    # so a position that kept only one expert gives it weight ~1)
    denom = gate_1 + gate_2 + 1e-9
    gate_1 /= denom
    gate_2 /= denom

    # [batch, group, experts, expert_capacity]
    backward_assignment = (
        gate_1 * mask_1_flat * mtf.one_hot(index_1, experts_dim) *
        mtf.one_hot(mtf.to_int32(position_in_expert_1), expert_capacity_dim) +
        gate_2 * mask_2_flat * mtf.one_hot(index_2, experts_dim) *
        mtf.one_hot(mtf.to_int32(position_in_expert_2), expert_capacity_dim))

    # 1 wherever backward_assignment is nonzero, 0 elsewhere.
    forward_assignment = mtf.cast(mtf.cast(backward_assignment, tf.bool),
                                  backward_assignment.dtype)

    return forward_assignment, backward_assignment, loss