def _is_finished(i, unused_alive_seq, alive_log_probs, unused_finished_seq,
                     finished_scores, finished_in_finished, *unused_states):
        """Checking termination condition.

    We terminate when we have decoded up to decode_length, or when the lowest
    scoring item in finished has a greater score than the highest-probability
    item in alive divided by the maximum length penalty.

    Args:
      i: loop index
      alive_log_probs: probabilities of the beams. [batch_size, beam_size]
      finished_scores: scores for each of these sequences.
        [batch_size, beam_size]
      finished_in_finished: finished bools for each of these sequences.
        [batch_size, beam_size]

    Returns:
      Bool.
    """
        # TODO(noam): support a different decode length...
        # decode_length = mtf.constant(mesh, length_dim.size, dtype=tf.int32)

        # del alive_log_probs, finished_scores, finished_in_finished
        # return mtf.less(i, length_dim.size)
        if not stop_early:
            return mtf.less(i, decode_length)
        max_length_penalty = mtf.pow(((5. + mtf.to_float(decode_length)) / 6.),
                                     alpha)
        # The best possible score of the most likely alive sequence.
        lower_bound_alive_scores = mtf.gather(
            alive_log_probs, mtf.constant(mesh, 0, dtype=tf.int32),
            beam_dim) / max_length_penalty

        # Now compute the lowest score of a finished sequence in finished.
        # If a sequence isn't finished, we multiply its score by 0; since
        # scores are all negative, taking the min gives us the score of the
        # lowest-scoring finished item.
        lowest_score_of_finished_in_finished = mtf.reduce_min(
            finished_scores * mtf.to_float(finished_in_finished),
            reduced_dim=beam_dim)

        # If none of the sequences have finished, the min will be 0, so we
        # replace it with -INF. The score of any sequence in alive will be much
        # higher than -INF, so the termination condition will not be met.
        lowest_score_of_finished_in_finished += ((1. - mtf.to_float(
            mtf.reduce_any(finished_in_finished, reduced_dim=beam_dim))) *
                                                 -INF)

        bound_is_met = mtf.reduce_all(
            mtf.greater(lowest_score_of_finished_in_finished,
                        lower_bound_alive_scores))
        return mtf.logical_and(mtf.less(i, decode_length),
                               mtf.logical_not(bound_is_met))
    def grow_alive(curr_seq, curr_scores, curr_log_probs, curr_finished,
                   states):
        """Given sequences and scores, will gather the top k=beam size sequences.

    Args:
      curr_seq: current topk sequence that has been grown by one position.
        [batch, beam, length]
      curr_scores: scores for each of these sequences. [batch_size, beam_size]
      curr_log_probs: log probs for each of these sequences.
        [batch, beam]
      curr_finished: Finished flags for each of these sequences.
        [batch, beam]
      states: list of mtf.Tensor
    Returns:
      Tuple of
        (Topk sequences based on scores,
         log probs of these sequences,
         Finished flags of these sequences)
    """
        # Set the scores of the finished seq in curr_seq to large negative
        # values
        curr_scores += mtf.to_float(curr_finished) * -INF
        return compute_topk_scores_and_seq(curr_seq, curr_scores,
                                           curr_log_probs, curr_finished,
                                           beam_dim, "grow_alive", states)
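A small NumPy sketch (not the mtf implementation) of the masking idea in grow_alive: finished candidates are pushed to -INF so an ordinary top-k keeps only still-alive beams.

import numpy as np

INF = 1e7  # stands in for the module-level INF constant assumed above

def grow_alive_scores(curr_scores, curr_finished):
    # curr_scores: [batch, 2*beam] floats, curr_finished: [batch, 2*beam] bools
    return curr_scores + curr_finished.astype(np.float32) * -INF

scores = np.array([[-1.0, -2.0, -0.5, -3.0]], dtype=np.float32)
finished = np.array([[False, True, True, False]])
masked = grow_alive_scores(scores, finished)
# A top-2 over the masked scores keeps only the unfinished candidates.
top2 = np.argsort(-masked, axis=-1)[:, :2]
print(top2)  # [[0 3]]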
Example #3
def multihead_self_attention_incremental(query_antecedent,
                                         prev_k,
                                         prev_v,
                                         step_num,
                                         name="multihead_attention"):
  """Incremental self-attention (one decode step).

  In order to use only one variable containing the four weight matrices
  packed together, we insist that the query and memory antecedents have the
  same dimensionality (io_channels) and that the keys and values have the
  same dimensionality (kv_channels).

  Args:
    query_antecedent: a mtf.Tensor with shape [batch..., io_channels]
    prev_k: mtf.Tensor with shape [batch..., heads, memory_length, kv_channels]
    prev_v: mtf.Tensor with shape [batch..., heads, memory_length, kv_channels]
    step_num: mtf Scalar with dtype tf.int32
    name: an optional string.

  Returns:
    y: A mtf.Tensor with shape [batch..., io_channels]
    new_k: mtf.Tensor with shape [batch..., heads, memory_length, kv_channels]
    new_v: mtf.Tensor with shape [batch..., heads, memory_length, kv_channels]

  Raises:
    ValueError: if the dimensions do not match.
  """
  batch_dims = query_antecedent.shape.dims[:-1]
  io_channels = query_antecedent.shape.dims[-1]
  heads, memory_length, kv_channels = prev_k.shape.dims[-3:]
  with tf.variable_scope(name, default_name="multihead_attention"):
    q_var, k_var, v_var, o_var = multihead_attention_vars(
        query_antecedent.mesh, heads, io_channels, kv_channels,
        query_antecedent.dtype)
    memory_antecedent = query_antecedent
    q = mtf.einsum(
        [query_antecedent, q_var],
        mtf.Shape(batch_dims + [heads, kv_channels]))
    k = mtf.einsum(
        [memory_antecedent, k_var],
        mtf.Shape(batch_dims + [heads, kv_channels]))
    v = mtf.einsum(
        [memory_antecedent, v_var],
        mtf.Shape(batch_dims + [heads, kv_channels]))
    k = prev_k + mtf.multiply(
        k, mtf.one_hot(step_num, memory_length), output_shape=prev_k.shape)
    v = prev_v + mtf.multiply(
        v, mtf.one_hot(step_num, memory_length), output_shape=prev_v.shape)

    mask = mtf.to_float(mtf.greater(mtf.range(
        query_antecedent.mesh, memory_length, dtype=tf.int32), step_num)
                       ) * -1e9
    o = dot_product_attention(q, k, v, mask)
    y = mtf.einsum([o, o_var], query_antecedent.shape)
    return y, k, v
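A NumPy sketch, illustrative only, of the cache update new_k = prev_k + k * one_hot(step_num) performed above: exactly one time slot of the cache is written per decode step.

import numpy as np

def update_cache(prev_cache, new_entry, step_num):
    # prev_cache: [memory_length, kv_channels], new_entry: [kv_channels]
    memory_length = prev_cache.shape[0]
    one_hot = np.eye(memory_length, dtype=prev_cache.dtype)[step_num]
    return prev_cache + new_entry[None, :] * one_hot[:, None]

cache = np.zeros((4, 2), dtype=np.float32)
cache = update_cache(cache, np.array([1.0, 2.0], dtype=np.float32), step_num=2)
print(cache)  # only row 2 is written: [[0 0] [0 0] [1 2] [0 0]]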
    def grow_finished(finished_seq, finished_scores, finished_flags, curr_seq,
                      curr_scores, curr_finished):
        """Given sequences and scores, will gather the top k=beam size sequences.

    Args:
      finished_seq: Current finished sequences.
        [batch, beam, length]
      finished_scores: scores for each of these sequences.
        [batch, beam]
      finished_flags: finished bools for each of these sequences.
        [batch, beam]
      curr_seq: current topk sequence that has been grown by one position.
        [batch, beam, length]
      curr_scores: scores for each of these sequences. [batch, beam]
      curr_finished: Finished flags for each of these sequences.
        [batch, beam]
    Returns:
      Tuple of
        (Topk sequences based on scores,
         log probs of these sequences,
         Finished flags of these sequences,
         None (no states))
    """

        # Set the scores of the unfinished seq in curr_seq to large negative
        # values
        curr_scores += (1. - mtf.to_float(curr_finished)) * -INF
        unused_batch_dim, beam_dim, unused_length_dim = finished_seq.shape.dims

        # concatenating the sequences and scores along beam axis
        def _my_concat(a, b):
            a = mtf.rename_dimension(a, "beam", "triple_beam")
            b = mtf.rename_dimension(b, "double_beam", "triple_beam")
            return mtf.concat([a, b], "triple_beam")

        curr_finished_seq = _my_concat(finished_seq, curr_seq)
        curr_finished_scores = _my_concat(finished_scores, curr_scores)
        curr_finished_flags = _my_concat(finished_flags, curr_finished)
        return compute_topk_scores_and_seq(curr_finished_seq,
                                           curr_finished_scores,
                                           curr_finished_scores,
                                           curr_finished_flags,
                                           beam_dim,
                                           "grow_finished",
                                           states=None)
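A NumPy sketch (not mtf) of grow_finished's concat-then-top-k: previously finished beams and newly finished candidates are concatenated along the beam axis, unfinished candidates are masked out, and only the best `beam` scores survive.

import numpy as np

beam = 2
finished_scores = np.array([[-3.0, -5.0]])          # [batch, beam]
curr_scores = np.array([[-2.0, -6.0, -4.0, -1.0]])  # [batch, 2*beam]
curr_finished = np.array([[True, False, True, True]])
# Unfinished candidates are pushed to -inf so they can never enter finished.
curr_scores = curr_scores + (1.0 - curr_finished) * -1e7
all_scores = np.concatenate([finished_scores, curr_scores], axis=1)  # 3*beam
keep = np.sort(all_scores, axis=1)[:, ::-1][:, :beam]
print(keep)  # [[-1. -2.]]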
Example #5
  def _sample(self, features, mesh):
    hparams = self._hparams
    (inputs_embedding_var,
     targets_embedding_var,
     softmax_var,
     positional_embedding_var) = self._embedding_and_softmax_vars(mesh)
    if self.has_input:
      inputs = features["inputs"]
      while len(inputs.shape.as_list()) > 2:
        inputs = tf.squeeze(inputs, axis=2)
      actual_batch_size = tf.shape(inputs)[0]
      actual_length = tf.shape(inputs)[1]
      inputs = tf.pad(
          inputs, [[0, hparams.batch_size - actual_batch_size],
                   [0, hparams.max_length - actual_length]])
      inputs = self._import_to_batch_by_length(
          inputs, "inputs", mesh, hparams)
      x = (mtf.gather(inputs_embedding_var, inputs, self.inputs_vocab_dim) +
           mtf.reshape(positional_embedding_var,
                       mtf.Shape([self.length_dim, self.model_dim])))
      encoder_attention_mask = (
          mtf_layers.attention_mask_ignore_padding(
              inputs, dtype=self.activation_dtype))
      with tf.variable_scope("encoder"):
        x = self._layer_stack(x,
                              hparams.num_encoder_layers,
                              self_attention_mask=encoder_attention_mask)
      encoder_output = mtf.rename_dimension(
          x, self.length_dim.name, self.memory_length_dim.name)
      encdec_tensors = []
      for layer_num in xrange(hparams.num_decoder_layers):
        with tf.variable_scope("decoder/layer_%d/encdec_attention" % layer_num):
          q_var, k_var, v_var, o_var = mtf_layers.multihead_attention_vars(
              mesh, self.heads_dim, self.model_dim,
              self.kv_dim, self.activation_dtype)
          k = mtf.einsum(
              [encoder_output, k_var],
              mtf.Shape(
                  [self.batch_dim, self.heads_dim,
                   self.memory_length_dim, self.kv_dim]))
          v = mtf.einsum(
              [encoder_output, v_var],
              mtf.Shape(
                  [self.batch_dim, self.heads_dim,
                   self.memory_length_dim, self.kv_dim]))
        encdec_tensors.append((q_var, o_var, k, v))
      partial_targets = None
    else:
      encdec_tensors = None
      encoder_output = None
      encoder_attention_mask = None
      # Prepare partial targets.
      # In either features["inputs"] or features["targets"].
      # We force the outputs to begin with these sequences.
      partial_targets = features.get("inputs", None)
      if partial_targets is None:
        partial_targets = features.get("targets", None)
      if partial_targets is not None:
        partial_targets = common_layers.expand_squeeze_to_nd(partial_targets, 2)
        partial_targets = tf.to_int32(partial_targets)
        partial_targets_batch = tf.shape(partial_targets)[0]
        partial_targets_length = tf.shape(partial_targets)[1]
        partial_targets = tf.pad(
            partial_targets, [[0, hparams.batch_size - partial_targets_batch],
                              [0, hparams.max_length - partial_targets_length]])
        partial_targets = self._import_to_batch_by_length(
            partial_targets, "partial_targets", mesh, hparams)

    if hparams.beam_size == 1:
      ids_shape = mtf.Shape([self.batch_dim, self.length_dim])
      kv_shape = mtf.Shape([self.batch_dim, self.heads_dim,
                            self.memory_length_dim, self.kv_dim])
    else:
      beam_dim = mtf.Dimension("beam", hparams.beam_size)
      ids_shape = mtf.Shape([self.batch_dim, beam_dim, self.length_dim])
      kv_shape = mtf.Shape([self.batch_dim, beam_dim, self.heads_dim,
                            self.memory_length_dim, self.kv_dim])

    initial_ids = mtf.constant(mesh, 0, ids_shape, dtype=tf.int32)
    initial_kv_states = (
        [mtf.zeros(mesh, kv_shape, dtype=self.activation_dtype)]
        * (2 * hparams.num_decoder_layers))
    def logits_fn(step_num, ids, states):
      """Produce logits for this step, and new states."""
      self_attention_k = states[:hparams.num_decoder_layers]
      self_attention_v = states[hparams.num_decoder_layers:]
      ids_this_step = mtf.gather(ids, step_num - 1, self.length_dim)
      x = (mtf.gather(targets_embedding_var, ids_this_step,
                      self.targets_vocab_dim) +
           mtf.gather(positional_embedding_var, step_num, self.max_length_dim))
      with tf.variable_scope("decoder"):
        x, new_self_attention_k, new_self_attention_v = (
            self._decoder_layer_stack_incremental(
                x,
                step_num,
                encdec_tensors,
                self_attention_k,
                self_attention_v,
                encdec_attention_mask=encoder_attention_mask))
      logits = mtf.matmul(x, softmax_var)
      return logits, new_self_attention_k + new_self_attention_v

    if hparams.beam_size == 1:
      temperature = (0.0 if hparams.sampling_method == "argmax"
                     else hparams.sampling_temp)
      return mtf_beam_search.greedy_decode(
          logits_fn,
          initial_ids,
          temperature=temperature,
          initial_states=initial_kv_states,
          forced_ids=partial_targets,
          use_tpu=hparams.use_tpu)
    else:
      if self.has_input:
        input_length = mtf.reduce_sum(
            mtf.to_float(mtf.cast(inputs, tf.bool)),
            reduced_dim=self.length_dim)
        max_input_length = mtf.reduce_max(input_length)
        decode_length = mtf.cast(
            max_input_length * hparams.decode_length_multiplier
            + hparams.decode_length_constant, tf.int32)
      else:
        decode_length = None
      beams, unused_scores = mtf_beam_search.beam_search(
          logits_fn,
          initial_ids,
          hparams.alpha,
          states=initial_kv_states,
          decode_length=decode_length,
          use_tpu=hparams.use_tpu)
      return mtf.gather(beams, mtf.constant(mesh, 0, dtype=tf.int32), beam_dim)
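For illustration, a plain NumPy version of the decode-length computation in the beam-search branch above; `multiplier` and `constant` stand in for hparams.decode_length_multiplier and hparams.decode_length_constant, and the values are hypothetical.

import numpy as np

def compute_decode_length(inputs, multiplier, constant):
    # inputs: [batch, length] int ids, 0 = padding
    input_length = (inputs != 0).sum(axis=1)   # per-example token counts
    max_input_length = input_length.max()
    return int(max_input_length * multiplier + constant)

padded_inputs = np.array([[5, 7, 2, 0, 0],
                          [3, 9, 4, 8, 0]])
print(compute_decode_length(padded_inputs, multiplier=1.5, constant=10))  # 16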
Example #6
def _truncated_top_2_gating_mtf(
    gates, group_dim, experts_dim, expert_capacity_dim):
  """Compute gating for mixture-of-experts in TensorFlow.

  gates is usually the output of a softmax function.
  The return value is a dense representation of the mapping between
  the input positions and the positions in the batches sent to the experts.

  TODO(noam): this function contains code factored out of
  expert_utils.local_moe_tpu.  Move this function to that file and
  call it from both places.

  Args:
    gates: a Tensor
    group_dim: one dimension of gates
    experts_dim: one dimension of gates
    expert_capacity_dim: a Dimension not in gates

  Returns:
    a Tensor with shape gates.shape + expert_capacity_dim

  Raises:
    ValueError: if expert_capacity_dim has size > 256
  """
  gates = mtf.to_float(gates)
  expert_capacity_f = float(expert_capacity_dim.size)
  # Find the top expert for each position. shape=[batch, group]
  index_1, gate_1 = mtf.top_1(gates, experts_dim)
  # [batch, group, experts]
  mask_1 = mtf.one_hot(index_1, experts_dim, dtype=gates.dtype)

  if expert_capacity_dim.size > 256:
    # using mtf.cumsum (implemented on TPU as bfloat16 matmul) to compute
    # position in the mini-batch sent to the expert.  This will cause
    # very bad things to happen if expert_capacity_dim > 256.
    raise ValueError(
        "expert_capacity_dim.size must be <=256 to avoid roundoff errors in"
        " indices - got %s" % (expert_capacity_dim,))
  # [batch, group, experts]
  # This is the position within the expert's mini-batch for this sequence
  position_in_expert_1 = mtf.cumsum(mask_1, group_dim, exclusive=True) * mask_1
  # Remove the elements that don't fit. [batch, group, experts]
  mask_1 *= mtf.to_float(mtf.less(position_in_expert_1, expert_capacity_f))
  # [batch, experts]
  # How many examples in this sequence go to this expert
  mask_1_count = mtf.reduce_sum(mask_1, reduced_dim=group_dim)
  # [batch, group] - mostly ones, but zeros where something didn't fit
  mask_1_flat = mtf.reduce_sum(mask_1, reduced_dim=experts_dim)
  # [batch, group]
  position_in_expert_1 = mtf.reduce_sum(
      position_in_expert_1, reduced_dim=experts_dim)
  # Weight assigned to first expert.  [batch, group]
  gate_1 *= mask_1_flat

  # Pick a second-place expert for each position.
  # We first mask out the experts that we expect to be over-capacity
  # [batch, experts]
  space_remaining = expert_capacity_f - mask_1_count
  use_rate = (mask_1_count + 1.0) / float(group_dim.size)
  # At what point in the sequence do we expect the expert to be full.
  # [batch, experts]
  expected_exhaustion_pos = space_remaining / use_rate
  # A Tensor with shape [batch, group, experts] representing a boolean
  #   - whether we expect that the expert will already be full.
  expected_exhausted = mtf.to_float(mtf.greater(
      mtf.range(gates.mesh, group_dim, tf.float32), expected_exhaustion_pos))
  masked_gates = gates - mask_1 - expected_exhausted
  # This section is similar to the section above.
  # [batch, group]
  index_2, gate_2 = mtf.top_1(masked_gates, experts_dim)
  # [batch, group, experts]
  mask_2 = mtf.one_hot(index_2, experts_dim, dtype=gates.dtype)
  # [batch, group, experts]
  position_in_expert_2 = (
      mtf.cumsum(mask_2, group_dim, exclusive=True) + mask_1_count)
  position_in_expert_2 *= mask_2
  mask_2 *= mtf.to_float(mtf.less(position_in_expert_2, expert_capacity_f))
  # mask_2_count = mtf.reduce_sum(mask_2, reduced_dim=experts_dim)
  mask_2_flat = mtf.reduce_sum(mask_2, reduced_dim=experts_dim)
  position_in_expert_2 = mtf.reduce_sum(
      position_in_expert_2, reduced_dim=experts_dim)
  gate_2 *= mask_2_flat

  # renormalize the two gate values to add up to 1
  denom = gate_1 + gate_2 + 1e-9
  gate_1 /= denom
  gate_2 /= denom

  # [batch, group, experts, expert_capacity]
  assignment = (
      gate_1 * mask_1_flat
      * mtf.one_hot(index_1, experts_dim)
      * mtf.one_hot(mtf.to_int32(position_in_expert_1), expert_capacity_dim) +
      gate_2 * mask_2_flat
      * mtf.one_hot(index_2, experts_dim)
      * mtf.one_hot(mtf.to_int32(position_in_expert_2), expert_capacity_dim))

  return assignment
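A NumPy sketch (not mtf) of the exclusive-cumsum trick used above to give each position a slot within its chosen expert's mini-batch.

import numpy as np

def position_in_expert(mask):
    # mask: [group, experts] one-hot rows; the exclusive cumsum over the group
    # axis counts how many earlier positions already picked the same expert.
    exclusive_cumsum = np.cumsum(mask, axis=0) - mask
    return exclusive_cumsum * mask

mask = np.array([[1, 0],
                 [1, 0],
                 [0, 1],
                 [1, 0]], dtype=np.float32)
print(position_in_expert(mask))
# positions 0, 1, 3 take slots 0, 1, 2 of expert 0; position 2 takes slot 0 of
# expert 1.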
Example #7
    def mtf_model_fn(self, features, mesh):
        features = copy.copy(features)
        tf.logging.info("features = %s" % features)
        hparams = self._hparams
        activation_dtype = self.set_activation_type()
        is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN

        # Declare all the dimensions
        batch_dim = mtf.Dimension("batch", hparams.batch_size)
        hidden_dim = mtf.Dimension("hidden", hparams.hidden_size)
        filter_h_dim = mtf.Dimension("filter_height", 7)
        filter_w_dim = mtf.Dimension("filter_width", 7)
        filters = mtf.Dimension("filters", hparams.filter_sizes[0])
        rows_dim = mtf.Dimension("rows_size", 32)
        cols_dim = mtf.Dimension("cols_size", 96)
        row_blocks_dim = mtf.Dimension("row_blocks", hparams.row_blocks)
        col_blocks_dim = mtf.Dimension("col_blocks", hparams.col_blocks)
        classes_dim = mtf.Dimension("classes", 10)
        one_channel_dim = mtf.Dimension("one_channel", 1)

        inputs = features["inputs"]
        x = mtf.import_tf_tensor(
            mesh,
            tf.reshape(inputs, [
                hparams.batch_size, hparams.row_blocks,
                hparams.rows_size // hparams.row_blocks, hparams.col_blocks,
                hparams.num_channels * hparams.cols_size // hparams.col_blocks,
                1
            ]),
            mtf.Shape([
                batch_dim, row_blocks_dim, rows_dim, col_blocks_dim, cols_dim,
                one_channel_dim
            ]))
        x = mtf.transpose(x, [
            batch_dim, row_blocks_dim, col_blocks_dim, rows_dim, cols_dim,
            one_channel_dim
        ])

        x = mtf.to_float(x)
        initial_filters = mtf.get_variable(
            mesh, "init_filters",
            mtf.Shape([filter_h_dim, filter_w_dim, one_channel_dim, filters]))
        x = mtf.conv2d_with_blocks(x,
                                   initial_filters,
                                   strides=[1, 1, 1, 1],
                                   padding="SAME",
                                   h_blocks_dim=None,
                                   w_blocks_dim=col_blocks_dim)

        x = batch_norm_relu(x, is_training)

        # Conv blocks
        # [ self attention - ffn - residual + dropout] x n
        for layer in range(hparams.num_layers):
            layer_name = "block_layer_%d" % layer
            with tf.variable_scope(layer_name):
                # Residual block layer
                x = block_layer(inputs=x,
                                filters=hparams.filter_sizes[0],
                                blocks=hparams.layer_sizes[0],
                                strides=[1, 1, 1, 1],
                                is_training=is_training,
                                name="block_layer1",
                                row_blocks_dim=None,
                                col_blocks_dim=None)
                x = block_layer(inputs=x,
                                filters=hparams.filter_sizes[1],
                                blocks=hparams.layer_sizes[1],
                                strides=[1, 2, 2, 1],
                                is_training=is_training,
                                name="block_layer2",
                                row_blocks_dim=None,
                                col_blocks_dim=None)
                x = block_layer(inputs=x,
                                filters=hparams.filter_sizes[2],
                                blocks=hparams.layer_sizes[2],
                                strides=[1, 2, 2, 1],
                                is_training=is_training,
                                name="block_layer3",
                                row_blocks_dim=None,
                                col_blocks_dim=None)

        # Calculate the logits and loss.
        out = x
        outputs = mtf_layers.dense(out,
                                   hidden_dim,
                                   reduced_dims=out.shape.dims[-5:],
                                   activation=mtf.relu,
                                   name="dense")

        # We assume fixed vocab size for targets
        labels = tf.squeeze(tf.to_int32(features["targets"]), [2, 3])
        labels = mtf.import_tf_tensor(mesh,
                                      tf.reshape(labels, [hparams.batch_size]),
                                      mtf.Shape([batch_dim]))

        logits = mtf_layers.dense(outputs, classes_dim, name="logits")
        soft_targets = mtf.one_hot(labels, classes_dim, dtype=activation_dtype)
        loss = mtf_layers.softmax_cross_entropy_with_logits(
            logits, soft_targets, classes_dim)

        # Reshape logits so it doesn't break inside t2t.
        logits = mtf.reshape(
            logits, mtf.Shape([batch_dim, one_channel_dim, classes_dim]))
        loss = mtf.reduce_mean(loss)
        return logits, loss
    def grow_topk(i, alive_seq, alive_log_probs, states=None):
        r"""Inner beam search loop.

    This function takes the current alive sequences, and grows them to topk
    sequences where k = 2*beam. We use 2*beam because beam_size of the
    sequences might hit <EOS>, leaving no alive sequences to continue. With
    2*beam_size this cannot happen, assuming the vocab size is > beam size: if
    so, we get at least beam_size non-<EOS> extensions when we extract the next
    top 2*beam words.
    The length penalty is ((5 + len(decode)) / 6) ^ alpha, and scores are the
    log probs divided by this penalty. Please refer to
    https://arxiv.org/abs/1609.08144.

    Args:
      i: loop index
      alive_seq: Topk sequences decoded so far [batch, beam, length]
      alive_log_probs: probabilities of these sequences. [batch, beam]
      states: optional list of mtf.Tensor
    Returns:
      Tuple of
        (Topk sequences extended by the next word,
         The log probs of these sequences,
         The scores with length penalty of these sequences,
         Flags indicating which of these sequences have finished decoding,
         list of transformed decoding states)
    """
        logits, new_states = logits_fn(i, alive_seq, states)
        batch_dim, beam_dim, vocab_dim = logits.shape.dims

        # Convert logits to normalized log probs
        candidate_log_probs = mtf.log_softmax(logits, vocab_dim)

        # Multiply the probabilities by the current probabilities of the beam.
        # (batch_size, beam_size, vocab_size) + (batch_size, beam_size, 1)
        log_probs = candidate_log_probs + alive_log_probs

        length_penalty = mtf.pow(((5. + mtf.to_float(i + 1)) / 6.), alpha)

        curr_scores = log_probs / length_penalty

        # scores have shape [batch, beam, vocab]
        beam_and_vocab_dim = mtf.Dimension("beam_and_vocab",
                                           beam_dim.size * vocab_dim.size)
        flat_shape = mtf.Shape([batch_dim, beam_and_vocab_dim])
        double_beam = mtf.Dimension("double_beam", beam_dim.size * 2)
        # Flatten out (beam_size, vocab_size) probs in to a list of possibilities
        flat_curr_scores = mtf.reshape(curr_scores, flat_shape)

        top_ids, top_scores = mtf.top_k(flat_curr_scores,
                                        reduced_dim=beam_and_vocab_dim,
                                        new_dim=double_beam)

        # Recovering the log probs because we will need to send them back
        top_log_probs = top_scores * length_penalty

        # Work out what beam the top probs are in.
        top_beam_index = top_ids // vocab_dim.size
        top_ids %= vocab_dim.size  # Unflatten the ids

        def my_gather(tensor):
            return mtf.gather(tensor,
                              top_beam_index,
                              beam_dim,
                              output_shape=mtf.Shape([
                                  double_beam if d == beam_dim else d
                                  for d in tensor.shape.dims
                              ]))

        # Gather up the most probable 2*beams both for the ids and finished_in_alive
        # bools
        top_seq = my_gather(alive_seq)

        if states:
            states = [my_gather(state) for state in new_states]

        # Append the most probable alive
        top_seq += top_ids * mtf.one_hot(i, length_dim, dtype=tf.int32)
        top_finished = mtf.equal(top_ids, eos_id)

        return top_seq, top_log_probs, top_scores, top_finished, states
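A NumPy sketch (not mtf) of how grow_topk recovers the beam index and the token id after flattening the [beam, vocab] scores and taking a single top-k.

import numpy as np

beam_size, vocab_size = 2, 5
scores = np.array([[-1.0, -4.0, -2.0, -9.0, -3.0],
                   [-2.5, -0.5, -6.0, -1.5, -7.0]])
flat = scores.reshape(-1)                  # [beam * vocab]
top = np.argsort(-flat)[:2 * beam_size]    # indices into the flat array
top_beam_index = top // vocab_size         # which beam each candidate came from
top_ids = top % vocab_size                 # which vocab token was appended
print(top_beam_index, top_ids)  # [1 0 1 0] [1 0 3 2]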
Example #9
def _top_2_gating(inputs,
                  outer_expert_dims,
                  experts_dim,
                  expert_capacity_dim,
                  hparams,
                  train,
                  importance=None):
    """Compute gating for mixture-of-experts in TensorFlow.

  Note: until the algorithm and interface solidify, we pass in a hyperparameters
  dictionary in order not to complicate the interface in mtf_transformer.py .
  Once this code moves out of "research", we should pass the hyperparameters
  separately.

  Hyperparameters used:
    hparams.moe_use_second_place_loss: a boolean
    hparams.moe_second_policy_train: a string
    hparams.moe_second_policy_eval: a string
    hparams.moe_second_threshold: a float

  The returned dispatch_tensor is used to map (via einsum) from the
  inputs to the expert_inputs.  Likewise, the returned combine_tensor is
  used to map (via einsum) from the expert outputs to the outputs.  Both
  tensors are mostly zeros.  The shapes of the tensors are as follows.

  inputs: [<batch_dims>, group_size_dim, input_dim]
  importance: [<batch_dims>, group_size_dim]
  dispatch_tensor:
    [<batch_dims>, group_size_dim, experts_dim, expert_capacity_dim]
  expert_inputs:
    [<batch_dims>, experts_dim, expert_capacity_dim, input_dim]

  expert_outputs: [<batch_dims>, experts_dim, expert_capacity_dim, output_dim]
  combine_tensor:
    [<batch_dims>, group_size_dim, experts_dim, expert_capacity_dim]
  outputs: [<batch_dims>, group_size_dim, output_dim]

  "importance" is an optional tensor with one floating-point value for each
  input vector.  If the importance of an input is 1.0, then we send it to
  up to 2 experts.  If 0.0 < importance < 1.0, then we send it to at most
  one expert.  If importance == 0.0, then we send it to no experts.

  We use "importance" at the second-level gating function of a hierarchical
  mixture of experts.  Inputs to the first-choice expert-group get importance
  1.0.  Inputs to the second-choice expert group get importance 0.5.
  Inputs that represent padding get importance 0.0.

  Args:
    inputs: a mtf.Tensor with shape [<batch_dims>, group_size_dim, input_dim]
    outer_expert_dims: an optional list of dimensions.  This is for the case
      where we are at an inner level of a hierarchical MoE.
    experts_dim: a Dimension (the number of experts)
    expert_capacity_dim: a Dimension (number of examples per group per expert)
    hparams: model hyperparameters.
    train: a boolean
    importance: an optional tensor with shape [<batch_dims>, group_size_dim]

  Returns:
    dispatch_tensor: a Tensor with shape
      [<batch_dims>, group_size_dim, experts_dim, expert_capacity_dim]
    combine_tensor: a Tensor with shape
      [<batch_dims>, group_size_dim, experts_dim, expert_capacity_dim]
    loss: a mtf scalar

  Raises:
    ValueError: on illegal hyperparameters
  """
    group_size_dim, unused_input_dim = inputs.shape.dims[-2:]

    raw_gates = mtf.softmax(
        mtf_layers.dense(inputs,
                         experts_dim,
                         use_bias=False,
                         expert_dims=outer_expert_dims), experts_dim)

    # The internals of this function run in float32.
    #   bfloat16 seems to reduce quality.
    raw_gates = mtf.to_float(raw_gates)

    expert_capacity_f = float(expert_capacity_dim.size)

    # FIND TOP 2 EXPERTS PER POSITION
    # Find the top expert for each position. shape=[batch, group]
    index_1, gate_1 = mtf.top_1(raw_gates, experts_dim)
    # [batch, group, experts]
    mask_1 = mtf.one_hot(index_1, experts_dim, dtype=raw_gates.dtype)
    density_1_proxy = raw_gates
    if importance is not None:
        mask_1 *= mtf.to_float(mtf.equal(importance, 1.0))
        gate_1 *= mtf.to_float(mtf.equal(importance, 1.0))
        density_1_proxy *= mtf.to_float(mtf.equal(importance, 1.0))
    gates_without_top_1 = raw_gates * (1.0 - mask_1)
    # [batch, group]
    index_2, gate_2 = mtf.top_1(gates_without_top_1, experts_dim)
    # [batch, group, experts]
    mask_2 = mtf.one_hot(index_2, experts_dim, dtype=raw_gates.dtype)
    if importance is not None:
        mask_2 *= mtf.to_float(mtf.greater(importance, 0.0))

    denom = gate_1 + gate_2 + 1e-9
    gate_1 /= denom
    gate_2 /= denom

    # BALANCING LOSSES
    # shape = [batch, experts]
    # We want to equalize the fraction of the batch assigned to each expert
    density_1 = mtf.reduce_mean(mask_1, reduced_dim=group_size_dim)
    # Something continuous that is correlated with what we want to equalize.
    density_1_proxy = mtf.reduce_mean(density_1_proxy,
                                      reduced_dim=group_size_dim)
    density_1 = mtf.Print(
        density_1, [mtf.reduce_mean(density_1, output_shape=[experts_dim])],
        "density_1",
        summarize=1000)
    loss = (mtf.reduce_mean(density_1_proxy * density_1) *
            float(experts_dim.size * experts_dim.size))

    if hparams.moe_use_second_place_loss:
        # Also add a loss to encourage all experts to be used equally also as the
        # second-place expert.  Experimentally, this seems to be a wash.
        # We want to equalize the fraction of the batch assigned to each expert:
        density_2 = mtf.reduce_mean(mask_2, reduced_dim=group_size_dim)
        # As a proxy for density_2, we renormalize the raw gates after the top one
        # has been removed.
        normalized = gates_without_top_1 / (mtf.reduce_sum(
            gates_without_top_1, reduced_dim=experts_dim) + 1e-9)
        density_2_proxy = mtf.reduce_mean(normalized,
                                          reduced_dim=group_size_dim)
        loss_2 = (mtf.reduce_mean(density_2_proxy * density_2) *
                  float(experts_dim.size * experts_dim.size))
        loss += loss_2 * 0.5

    # Depending on the policy in the hparams, we may drop out some of the
    # second-place experts.
    policy = (hparams.moe_second_policy_train
              if train else hparams.moe_second_policy_eval)
    threshold = (hparams.moe_second_threshold_train
                 if train else hparams.moe_second_threshold_eval)
    if policy == "all":
        # Use second-place experts for all examples.
        pass
    elif policy == "none":
        # Never use second-place experts.
        mask_2 = mtf.zeros_like(mask_2)
    elif policy == "threshold":
        # Use second-place experts if gate_2 > threshold.
        mask_2 *= mtf.to_float(mtf.greater(gate_2, threshold))
    elif policy == "random":
        # Use second-place experts with probability min(1.0, gate_2 / threshold).
        mask_2 *= mtf.to_float(
            mtf.less(mtf.random_uniform(gate_2.mesh, gate_2.shape),
                     gate_2 / max(threshold, 1e-9)))
    else:
        raise ValueError("Unknown policy %s" % policy)
    mask_2 = mtf.Print(mask_2,
                       [mtf.reduce_mean(mask_2, output_shape=[experts_dim])],
                       "density_2",
                       summarize=1000)

    # COMPUTE ASSIGNMENT TO EXPERTS
    # [batch, group, experts]
    # This is the position within the expert's mini-batch for this sequence
    position_in_expert_1 = mtf.cumsum(mask_1, group_size_dim,
                                      exclusive=True) * mask_1
    # Remove the elements that don't fit. [batch, group, experts]
    mask_1 *= mtf.to_float(mtf.less(position_in_expert_1, expert_capacity_f))
    # [batch, experts]
    # How many examples in this sequence go to this expert
    mask_1_count = mtf.reduce_sum(mask_1, reduced_dim=group_size_dim)
    # [batch, group] - mostly ones, but zeros where something didn't fit
    mask_1_flat = mtf.reduce_sum(mask_1, reduced_dim=experts_dim)
    # [batch, group]
    position_in_expert_1 = mtf.reduce_sum(position_in_expert_1,
                                          reduced_dim=experts_dim)
    # Weight assigned to first expert.  [batch, group]
    gate_1 *= mask_1_flat

    # [batch, group, experts]
    position_in_expert_2 = (
        mtf.cumsum(mask_2, group_size_dim, exclusive=True) + mask_1_count)
    position_in_expert_2 *= mask_2
    mask_2 *= mtf.to_float(mtf.less(position_in_expert_2, expert_capacity_f))
    # mask_2_count = mtf.reduce_sum(mask_2, reduced_dim=experts_dim)
    mask_2_flat = mtf.reduce_sum(mask_2, reduced_dim=experts_dim)
    gate_2 *= mask_2_flat
    position_in_expert_2 = mtf.reduce_sum(position_in_expert_2,
                                          reduced_dim=experts_dim)

    # [batch, group, experts, expert_capacity]
    combine_tensor = (
        gate_1 * mask_1_flat * mtf.one_hot(index_1, experts_dim) *
        mtf.one_hot(mtf.to_int32(position_in_expert_1), expert_capacity_dim) +
        gate_2 * mask_2_flat * mtf.one_hot(index_2, experts_dim) *
        mtf.one_hot(mtf.to_int32(position_in_expert_2), expert_capacity_dim))

    combine_tensor = mtf.cast(combine_tensor, inputs.dtype)
    loss = mtf.cast(loss, inputs.dtype)

    dispatch_tensor = mtf.cast(mtf.cast(combine_tensor, tf.bool),
                               combine_tensor.dtype)

    return dispatch_tensor, combine_tensor, loss
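A NumPy sketch (not mtf) of the load-balancing auxiliary loss computed above: mean(density_1_proxy * density_1) * num_experts**2 evaluates to 1.0 for perfectly balanced routing and grows as routing concentrates on fewer experts.

import numpy as np

def balance_loss(gates, top1_index, num_experts):
    # gates: [group, experts] softmax outputs; top1_index: [group] argmax picks
    mask_1 = np.eye(num_experts)[top1_index]    # hard assignments
    density_1 = mask_1.mean(axis=0)             # fraction routed to each expert
    density_1_proxy = gates.mean(axis=0)        # differentiable proxy
    return float((density_1_proxy * density_1).mean() * num_experts ** 2)

uniform_gates = np.full((4, 2), 0.5)
skewed_gates = np.tile([0.9, 0.1], (4, 1))
print(balance_loss(uniform_gates, np.array([0, 1, 0, 1]), 2))  # 1.0 (balanced)
print(balance_loss(skewed_gates, np.array([0, 0, 0, 0]), 2))   # 1.8 (skewed)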
Example #10
def transformer_moe_layer_v2(inputs, output_dim, hparams, train):
    """2-level mixture of experts.

  Adapted from the paper https://arxiv.org/abs/1701.06538

  Note: until the algorithm and interface solidify, we pass in a hyperparameters
  dictionary in order not to complicate the interface in mtf_transformer.py .
  Once this code moves out of "research", we should pass the hyperparameters
  separately.

  Hyperparameters used:
    hparams.moe_num_experts: number of experts
    hparams.moe_hidden_size: size of hidden layer in each expert
    hparams.moe_group_size: size of each "group" for gating purposes
    hparams.moe_capacity_factor_train: a float
    hparams.moe_capacity_factor_eval: a float
    hparams.moe_capacity_factor_second_level: a float
    hparams.moe_gating: a string
    + all hyperparameters used by _top_2_gating()

  There is one set of gating params for the experts in the first level and a
  different set of gating params per first-level expert in the second level.
  The number of parameters in the gating network is:
    (input_dim.size * hparams.moe_num_experts[0]) +
    (input_dim.size * hparams.moe_num_experts[0] * hparams.moe_num_experts[1])


  The number of parameters in the experts themselves is:
    (hparams.num_experts
     * (input_dim.size + output_dim.size)
     * hparams.moe_hidden_size)

  The input is n-dimensional: [<batch_and_length_dims>, input_dim], consisting
  of the representations of all positions in a batch of sequences.

  Each position of each sequence is sent to 0-3 experts.  The expert
  choices and the combination weights are determined by a learned gating
  function.

  This function returns a small auxiliary loss that should be added to the
  training loss of the model.  This loss helps to balance expert usage.
  Without the loss, it is very likely that a few experts will be trained and
  the rest will starve.

  Several hacks are necessary to get around current TPU limitations:

  - To ensure static shapes, we enforce (by truncation/padding)
    that each sequence send the same number of elements to each expert.

    It would make more sense to enforce this equality over the entire batch,
    but due to our hacked-up gather-by-matmul implementation, we need to divide
    the batch into "groups".  For each group, the same number of elements
    are sent to each expert.

  TODO(noam): Factor this code better.  We want to be able to substitute
  different code for the experts themselves.

  Dimensions cheat sheet:
  a, b: batch size
  l: original sequence length
  m: input depth
  n: output depth
  g, h: number of groups
  s, t: group size
  x, y: number of experts
  c, d: expert capacity

  input: [a0, b1, l, m]
  input: [a0, g1, s, m]
  dispatch_tensor_x: [a0, g1, s, x, c]
  expert_input: [a0, g1, x, c, m]
  alltoall: [a0, g, x1, c, m]
  alltoall: [a0, g, x1, c, m]
  transpose: [x1, a0, g, c, m]
  reshape: [x1, h0, s, m]
  assignment2: [x1, h0, t, y, d]
  expert_input2: [x1, h0, y, d, m]
  alltoall: [x1, h, y0, d, m]
  ...
  reverse of that

  gating params 0: [m, x]
  gating params 1: [x1, m, y]

  expert params:
     [x1, y0, m, hidden]
     [x1, y0, hidden, n]

  Args:
    inputs: a mtf.Tensor with shape [a, b, l, m]
    output_dim: a mtf.Dimension (for Transformer, this is input_dim)
    hparams: model hyperparameters
    train: a boolean

  Returns:
    outputs: a Tensor with shape [a, b, l, n]
    loss: a mtf scalar

  Raises:
    ValueError: on unrecognized hparams.moe_gating
  """
    insert_outer_batch_dim = (len(inputs.shape.dims) == 3)
    if insert_outer_batch_dim:
        inputs = mtf.reshape(inputs, [mtf.Dimension("outer_batch", 1)] +
                             inputs.shape.dims)

    assert len(hparams.moe_num_experts) == 2
    a0, b1, l, m = inputs.shape.dims
    hidden_dim = mtf.Dimension("expert_hidden", hparams.moe_hidden_size)
    x1 = mtf.Dimension("expert_x", hparams.moe_num_experts[0])
    y0 = mtf.Dimension("expert_y", hparams.moe_num_experts[1])
    x = mtf.Dimension("expert_x_unsplit", hparams.moe_num_experts[0])
    y = mtf.Dimension("expert_y_unsplit", hparams.moe_num_experts[1])
    n = output_dim

    # We "cheat" here and look at the mesh shape and layout. This is to ensure
    # that the number of groups (g.size) is a multiple of the mesh dimension
    # over which those groups are split.
    num_groups, group_size = _split_into_groups(
        b1.size * l.size, hparams.moe_group_size,
        _tensor_dim_to_mesh_dim_size(hparams, b1))
    g1 = mtf.Dimension(b1.name, num_groups)
    g = mtf.Dimension(b1.name + "_unsplit", g1.size)
    s = mtf.Dimension("group_size_x", group_size)

    # Each sequence sends (at most?) expert_capacity positions to each expert.
    # Static expert_capacity dimension is needed for expert batch sizes
    capacity_factor = (hparams.moe_capacity_factor_train
                       if train else hparams.moe_capacity_factor_eval)
    expert_capacity = min(s.size, int((s.size * capacity_factor) / x.size))
    c = mtf.Dimension("expert_capacity_x", expert_capacity)

    # We "cheat" here and look at the mesh shape and layout. This is to ensure
    # that the number of groups (h.size) is a multiple of the mesh dimension
    # over which those groups are split.
    num_groups, group_size = _split_into_groups(
        a0.size * g.size * c.size, hparams.moe_group_size,
        _tensor_dim_to_mesh_dim_size(hparams, a0))
    t = mtf.Dimension("group_size_y", group_size)
    h0 = mtf.Dimension(a0.name, num_groups)
    h = mtf.Dimension(a0.name + "_unsplit", h0.size)

    expert_capacity = min(
        t.size,
        int((t.size * hparams.moe_capacity_factor_second_level) / y.size))
    d = mtf.Dimension("expert_capacity_y", expert_capacity)

    # First level of expert routing
    # Reshape the inner batch size to a multiple of group_dim g1 and
    # group_size_dim s.
    inputs = mtf.reshape(inputs, [a0, g1, s, m])

    # Get the assignments for the first level.
    # dispatch_tensor_x has shape [a0, g1, s, x, c]
    if hparams.moe_gating == "top_2":
        dispatch_tensor_x, combine_tensor_x, loss_outer = _top_2_gating(
            inputs=inputs,
            outer_expert_dims=None,
            experts_dim=x,
            expert_capacity_dim=c,
            hparams=hparams,
            train=train)
    else:
        raise ValueError("unknown hparams.moe_gating=%s" % hparams.moe_gating)

    # Now create expert_inputs based on the assignments.
    # put num_experts dimension first to make split easier in alltoall
    expert_inputs_x = mtf.einsum([inputs, dispatch_tensor_x],
                                 [x, a0, g1, c, m])

    # we construct an "importance" Tensor for the inputs to the second-level
    # gating.  The importance of an input is 1.0 if it represents the
    # first-choice expert-group and 0.5 if it represents the second-choice expert
    # group.  This is used by the second-level gating.
    importance = mtf.reduce_sum(combine_tensor_x, output_shape=[x, a0, g1, c])
    importance = 0.5 * (mtf.to_float(mtf.greater(importance, 0.5)) +
                        mtf.to_float(mtf.greater(importance, 0.0)))

    # First level, all to all. Here we change the split dimension from g1 to x1.
    expert_inputs_x = mtf.reshape(expert_inputs_x, mtf.Shape([x1, a0, g, c,
                                                              m]))
    importance = mtf.reshape(importance, [x1, a0, g, c])

    # Second level of expert routing
    # Reshape the expert_inputs outer batch dim to be a multiple of group_dim h0
    # and group_size_dim t.
    inputs_y = mtf.reshape(expert_inputs_x, [x1, h0, t, m])
    importance = mtf.reshape(importance, [x1, h0, t])

    # Get the assignments for the second level.
    # dispatch_tensor_y has shape [x1, h0, t, y, d]
    if hparams.moe_gating == "top_2":
        dispatch_tensor_y, combine_tensor_y, loss_inner = _top_2_gating(
            inputs=inputs_y,
            outer_expert_dims=[x1],
            experts_dim=y,
            expert_capacity_dim=d,
            hparams=hparams,
            train=train,
            importance=importance)
    else:
        raise ValueError("unknown hparams.moe_gating=%s" % hparams.moe_gating)

    # Now create expert_inputs based on the assignments.
    # put num_experts dimension first to make split easier in alltoall
    expert_inputs_y = mtf.einsum([inputs_y, dispatch_tensor_y],
                                 [y, x1, h0, d, m])

    # Second level, all to all. Here we change the split dimension from h0 to y0.
    expert_inputs_y = mtf.reshape(expert_inputs_y, mtf.Shape([y0, x1, h, d,
                                                              m]))

    # Now feed the expert inputs through the experts.
    hidden_output = mtf_layers.dense(expert_inputs_y,
                                     hidden_dim,
                                     expert_dims=[y0, x1],
                                     activation=mtf.relu,
                                     use_bias=False,
                                     name="expert0")
    expert_output = mtf_layers.dense(hidden_output,
                                     output_dim,
                                     expert_dims=[y0, x1],
                                     use_bias=False,
                                     name="expert1")

    # NOW COMBINE EXPERT OUTPUTS (reversing everything we have done)
    # expert_output has shape [y0, x1, h, d, n]

    # alltoall
    expert_output = mtf.reshape(expert_output, mtf.Shape([y, x1, h0, d, n]))

    # combine results from inner level
    output_y = mtf.einsum([expert_output, combine_tensor_y], [x1, h0, t, n])

    # Reshape the combined tensor from inner level to now contain outer_batch_dim
    # a0 and group_dim g
    output = mtf.reshape(output_y, [x1, a0, g, c, n])

    # alltoall from expert_dim x to group_dim g1
    expert_output_x = mtf.reshape(output, mtf.Shape([x, a0, g1, c, n]))

    # combine results from outer level
    output_x = mtf.einsum([expert_output_x, combine_tensor_x], [a0, g1, s, n])

    # Reshape the combined tensor to now contain inner_batch_dim
    # b1 and the original sequence length
    output = mtf.reshape(output_x, [a0, b1, l, n])
    if insert_outer_batch_dim:
        output = mtf.reshape(output, [b1, l, n])
    return output, (loss_outer + loss_inner) * hparams.moe_loss_coef
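A worked example (hypothetical numbers) of the expert-capacity formula used above: each group of group_size positions may send at most `capacity` positions to any single expert.

def expert_capacity(group_size, capacity_factor, num_experts):
    return min(group_size, int((group_size * capacity_factor) / num_experts))

# e.g. groups of 1024 positions, capacity factor 1.25, 16 first-level experts:
print(expert_capacity(1024, 1.25, 16))  # 80 slots per expert per group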
Example #11
File: moe.py  Project: y12uc231/BERT-1
def _top_2_gating(inputs, experts_dim, expert_capacity_dim, max_experts,
                  hparams, train):
    """Compute gating for mixture-of-experts in TensorFlow.

  Note: until the algorithm and interface solidify, we pass in a hyperparameters
  dictionary in order not to complicate the interface in mtf_transformer.py .
  Once this code moves out of "research", we should pass the hyperparameters
  separately.

  Hyperparameters used:
    hparams.moe_use_second_place_loss: a boolean
    hparams.moe_second_policy_train: a string
    hparams.moe_second_policy_eval: a string
    hparams.moe_second_threshold: a float

  max_experts is a float tensor with shape [batch_dim, group_dim]
  indicating at most how many experts to use per example.  This can be
  used to prevent padding from going to experts.

  The returned forward assignment is a tensor used to map (via einsum) from the
  inputs to the expert_inputs.  Likewise, the returned backward_assignment is
  used to map (via einsum) from the expert outputs to the outputs.  Both the
  forward and backward assignments are mostly zeros.  The shapes of all of these
  are as follows.

  inputs: [batch_dim, group_dim, input_dim]
  forward_assignment: [batch_dim, group_dim, experts_dim, expert_capacity_dim]
  expert_inputs: [batch_dim, experts_dim, expert_capacity_dim, input_dim]

  expert_outputs: [batch_dim, experts_dim, expert_capacity_dim, output_dim]
  backward_assignment: [batch_dim, group_dim, experts_dim, expert_capacity_dim]
  outputs: [batch_dim, group_dim, output_dim]

  Args:
    inputs: a mtf.Tensor with shape [batch_dim, group_dim, input_dim]
    experts_dim: a Dimension (the number of experts)
    expert_capacity_dim: a Dimension (number of examples per group per expert)
    max_experts: optional mtf.Tensor with shape [batch_dim, group_dim]
    hparams: model hyperparameters.
    train: a boolean

  Returns:
    forward_assignment: a Tensor with shape
      [batch_dim, group_dim, experts_dim, expert_capacity_dim]
    backward_assignment: a Tensor with shape
      [batch_dim, group_dim, experts_dim, expert_capacity_dim]
    loss: a mtf scalar

  Raises:
    ValueError: on illegal hyperparameters
  """
    unused_batch_dim, group_dim, unused_input_dim = inputs.shape.dims

    raw_gates = mtf.softmax(
        mtf_layers.dense(inputs, experts_dim, use_bias=False), experts_dim)

    expert_capacity_f = float(expert_capacity_dim.size)

    # FIND TOP 2 EXPERTS PER POSITION
    # Find the top expert for each position. shape=[batch, group]
    index_1, gate_1 = mtf.top_1(raw_gates, experts_dim)
    # [batch, group, experts]
    mask_1 = mtf.one_hot(index_1, experts_dim, dtype=raw_gates.dtype)
    gates_without_top_1 = raw_gates * (1.0 - mask_1)
    # [batch, group]
    index_2, gate_2 = mtf.top_1(gates_without_top_1, experts_dim)
    # [batch, group, experts]
    mask_2 = mtf.one_hot(index_2, experts_dim, dtype=raw_gates.dtype)

    if max_experts is not None:
        geq1 = mtf.to_float(mtf.greater_equal(max_experts, 1.0))
        geq2 = mtf.to_float(mtf.greater_equal(max_experts, 2.0))
        mask_1 *= geq1
        mask_2 *= geq2
        raw_gates *= geq1
        gates_without_top_1 *= geq2

    # BALANCING LOSSES
    # shape = [batch, experts]
    # We want to equalize the fraction of the batch assigned to each expert
    density_1 = mtf.reduce_mean(mask_1, reduced_dim=group_dim)
    # Something continuous that is correlated with what we want to equalize.
    density_1_proxy = mtf.reduce_mean(raw_gates, reduced_dim=group_dim)
    density_1 = mtf.Print(
        density_1, [mtf.reduce_mean(density_1, output_shape=[experts_dim])],
        "density_1",
        summarize=1000)
    loss = (mtf.reduce_mean(density_1_proxy * density_1) *
            float(experts_dim.size * experts_dim.size))

    if hparams.moe_use_second_place_loss:
        # Also add a loss to encourage all experts to be used equally also as the
        # second-place expert.  Experimentally, this seems to be a wash.
        # We want to equalize the fraction of the batch assigned to each expert:
        density_2 = mtf.reduce_mean(mask_2, reduced_dim=group_dim)
        # As a proxy for density_2, we renormalize the raw gates after the top one
        # has been removed.
        normalized = gates_without_top_1 / (mtf.reduce_sum(
            gates_without_top_1, reduced_dim=experts_dim) + 1e-9)
        density_2_proxy = mtf.reduce_mean(normalized, reduced_dim=group_dim)
        loss_2 = (mtf.reduce_mean(density_2_proxy * density_2) *
                  float(experts_dim.size * experts_dim.size))
        loss += loss_2 * 0.5

    # Depending on the policy in the hparams, we may drop out some of the
    # second-place experts.
    policy = (hparams.moe_second_policy_train
              if train else hparams.moe_second_policy_eval)
    threshold = (hparams.moe_second_threshold_train
                 if train else hparams.moe_second_threshold_eval)
    if policy == "all":
        # Use second-place experts for all examples.
        pass
    elif policy == "none":
        # Never use second-place experts.
        mask_2 = mtf.zeros_like(mask_2)
    elif policy == "threshold":
        # Use second-place experts if gate_2 > threshold.
        mask_2 *= mtf.to_float(mtf.greater(gate_2, threshold))
    elif policy == "random":
        # Use second-place experts with probability min(1.0, gate_2 / threshold).
        mask_2 *= mtf.to_float(
            mtf.less(mtf.random_uniform(gate_2.mesh, gate_2.shape),
                     gate_2 / max(threshold, 1e-9)))
    else:
        raise ValueError("Unknown policy %s" % policy)
    mask_2 = mtf.Print(mask_2,
                       [mtf.reduce_mean(mask_2, output_shape=[experts_dim])],
                       "density_2",
                       summarize=1000)

    # COMPUTE ASSIGNMENT TO EXPERTS
    # [batch, group, experts]
    # This is the position within the expert's mini-batch for this sequence
    position_in_expert_1 = mtf.cumsum(mask_1, group_dim,
                                      exclusive=True) * mask_1
    # Remove the elements that don't fit. [batch, group, experts]
    mask_1 *= mtf.to_float(mtf.less(position_in_expert_1, expert_capacity_f))
    # [batch, experts]
    # How many examples in this sequence go to this expert
    mask_1_count = mtf.reduce_sum(mask_1, reduced_dim=group_dim)
    # [batch, group] - mostly ones, but zeros where something didn't fit
    mask_1_flat = mtf.reduce_sum(mask_1, reduced_dim=experts_dim)
    # [batch, group]
    position_in_expert_1 = mtf.reduce_sum(position_in_expert_1,
                                          reduced_dim=experts_dim)
    # Weight assigned to first expert.  [batch, group]
    gate_1 *= mask_1_flat

    # [batch, group, experts]
    position_in_expert_2 = (mtf.cumsum(mask_2, group_dim, exclusive=True) +
                            mask_1_count)
    position_in_expert_2 *= mask_2
    mask_2 *= mtf.to_float(mtf.less(position_in_expert_2, expert_capacity_f))
    # mask_2_count = mtf.reduce_sum(mask_2, reduced_dim=experts_dim)
    mask_2_flat = mtf.reduce_sum(mask_2, reduced_dim=experts_dim)
    gate_2 *= mask_2_flat
    position_in_expert_2 = mtf.reduce_sum(position_in_expert_2,
                                          reduced_dim=experts_dim)

    # renormalize the two gate values to add up to 1
    denom = gate_1 + gate_2 + 1e-9
    gate_1 /= denom
    gate_2 /= denom

    # [batch, group, experts, expert_capacity]
    backward_assignment = (
        gate_1 * mask_1_flat * mtf.one_hot(index_1, experts_dim) *
        mtf.one_hot(mtf.to_int32(position_in_expert_1), expert_capacity_dim) +
        gate_2 * mask_2_flat * mtf.one_hot(index_2, experts_dim) *
        mtf.one_hot(mtf.to_int32(position_in_expert_2), expert_capacity_dim))

    forward_assignment = mtf.cast(mtf.cast(backward_assignment, tf.bool),
                                  backward_assignment.dtype)

    return forward_assignment, backward_assignment, loss
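A NumPy sketch (not mtf) of the max_experts masking above: positions whose budget is below 1 (e.g. padding) reach no expert, and positions with budget below 2 reach at most one.

import numpy as np

max_experts = np.array([2.0, 1.0, 0.0])        # per-position expert budget
geq1 = (max_experts >= 1.0).astype(np.float32)
geq2 = (max_experts >= 2.0).astype(np.float32)
mask_1 = np.ones(3, dtype=np.float32) * geq1   # keep the first-choice expert?
mask_2 = np.ones(3, dtype=np.float32) * geq2   # keep the second-choice expert?
print(mask_1, mask_2)  # [1. 1. 0.] [1. 0. 0.]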