def flat_beam_search(batch_size,
                     beam_size,
                     max_steps,
                     dec_callback,
                     dec_state,
                     bos_id=1,
                     eos_id=2,
                     length_norm_alpha=0.8,
                     beam_gap=3.0,
                     top_k_fn=tf.math.top_k,
                     prefix=None,
                     prefix_len=None,
                     fprop_dtype=tf.float32,
                     ext_size=0,
                     nbest_size=None,
                     debug=True):
    """Flat beam search.

  Args:
    batch_size: batch size
    beam_size: beam size limit in number of hyps
    max_steps: max steps
    dec_callback: decoder callback (see above and the interface note below)
    dec_state: decoder state
    bos_id: <s> token id
    eos_id: </s> token id
    length_norm_alpha: length normalization parameter
    beam_gap: early stopping threshold; None to disable
    top_k_fn: top_k function to call
    prefix: (optional) int32 tensor [batch_size, prefix_max]
    prefix_len: (optional) int32 tensor [batch_size]
    fprop_dtype: fprop dtype
    ext_size: extension buffer size; 0 to disable, otherwise an
      int >= beam_size
    nbest_size: number of returned hyps, default is beam_size
    debug: log intermediate values with tpu_summary.tensor()

  Returns:
    (loop_vars, dec_state, nbest) where
    nbest = (topk_ids, topk_len, topk_score)
  """
    assert beam_size > 0
    assert batch_size > 0
    assert max_steps > 0

    buf_size = beam_size * max_steps
    output_len = max_steps

    if prefix is None:
        assert prefix_len is None
        prefix = tf.zeros([batch_size, beam_size], dtype=tf.int32)
        prefix += tf.one_hot(0, beam_size, dtype=tf.int32) * bos_id
        prefix_len = tf.ones([batch_size], dtype=tf.int32)
    else:
        assert int(prefix.shape[0]) == batch_size, (batch_size, prefix.shape)
        assert int(prefix_len.shape[0]) == batch_size, (batch_size,
                                                        prefix_len.shape)
        output_len += int(prefix.shape[1])

    if debug:
        tpu_summary.tensor('prefix', prefix)
        tpu_summary.tensor('prefix_len', prefix_len)

    with tf.name_scope('init_state'):
        t = tf.constant(0)
        tgt_id = tf.zeros([batch_size, beam_size], dtype=tf.int32)
        tgt_id += bos_id
        tgt_pos = tf.zeros([batch_size, beam_size], dtype=tf.int32)
        tgt_mask = tf.zeros([batch_size, beam_size, buf_size],
                            dtype=fprop_dtype)
        tgt_mask += tf.one_hot(tf.range(beam_size),
                               buf_size,
                               dtype=fprop_dtype)
        hyp_score = tf.zeros([batch_size, beam_size], dtype=fprop_dtype)
        # penalize all hyps except the first
        hyp_score -= tf.cast(tf.range(beam_size, dtype=tf.float32) * 1e5,
                             dtype=fprop_dtype)
        nbest_size = nbest_size or beam_size
        nbest_score = tf.zeros([batch_size, nbest_size], dtype=fprop_dtype)
        nbest_score -= 1e9
        nbest_score_norm = nbest_score
        nbest_mask = tf.zeros([batch_size, nbest_size, buf_size],
                              dtype=fprop_dtype)

    with tf.name_scope('init_ext'):
        # Initialize the extension buffer.
        #
        # Extension buffer stores a (potentially large) set of 'extensions',
        # which consist of a hypothesis (represented by ext_mask) and next token
        # (represented by ext_id). At each decoder iteration, top_k extensions
        # from each hypothesis are added to the buffer and sorted by score.
        #
        # Then top beam_size extensions are removed from the buffer and used
        # in the next decoder iteration. And top 'ext_size' remaining extensions
        # are carried over to be possibly evaluated at a later step.
        #
        # As a result of this manipulation, the decoder is no longer restricted
        # to always compare hyps of the same token length at each iteration.
        # In particular, for a fixed length N it can generate more than beam_size
        # terminated hyps.
        #
        # Setting ext_size = 0 disables this feature.
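        #
        # A small worked example (illustrative numbers): with beam_size = 2
        # and ext_size = 4, each loop step proposes beam_size * beam_size = 4
        # new extensions. These are concatenated with the 4 carried-over
        # extensions, the combined 8 candidates are sorted by score, the best
        # 2 become the hyps evaluated at the next step, the following 4 are
        # kept in the extension buffer, and the worst 2 are dropped.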
        if ext_size:
            ext_id = tf.zeros([batch_size, ext_size], dtype=tf.int32)
            ext_score = tf.zeros([batch_size, ext_size], dtype=fprop_dtype)
            ext_score -= 1e9
            ext_mask = tf.zeros([batch_size, ext_size, buf_size],
                                dtype=fprop_dtype)
        else:
            ext_size = ext_id = ext_score = ext_mask = 0

    with tf.name_scope('init_prefix'):
        # rename prefix->pfx for shorter variables
        pfx = tf.cast(prefix, tf.int32)
        pfx_len = tf.cast(prefix_len, tf.int32)
        del prefix, prefix_len
        # Before the first call to dec_callback() the prefix shall be packed into
        # the tgt_id buffer as follows:
        #
        # [ P P P P P P - - - - - - P* - - - ]   ^
        # [ P P P P P P P P P P - - P* - - - ]   | batch
        # [ P - - - - - - - - - - - P* - - - ]   V
        # |<---- prefix len ---->  |<-- beam -->
        #
        # The last meaningful token in the prefix (P*)
        # must be located at the same position in all batch rows.
        #
        # We then make one dec_callback() with full prefix (minus P*)
        # which will populate the initial dec_state
        # (for transformer -- self-attention key/value cache)
        #
        # The last block [batch, beam] then becomes the first tgt_id for the loop.
        pfx_max = int(pfx.shape[1])
        pfx_mul = pfx_max // beam_size
        assert pfx_max == pfx_mul * beam_size, (pfx_max, pfx_mul, beam_size)
        pfx_time = tf.range(pfx_max)
        pfx_pad = tf.cast(
            tf.less(tf.expand_dims(pfx_time, 0),
                    tf.expand_dims(pfx_len - 1, 1)), tf.int32)
        pfx_id = pfx * pfx_pad
        pfx_last = einsum_i32(
            'BT,BT->B', pfx, tf.one_hot(pfx_len - 1,
                                        pfx_max,
                                        dtype=fprop_dtype))

        buf_time = tf.range(buf_size)
        pfx_time_mask = tf.cast(
            tf.less_equal(tf.expand_dims(buf_time, 0),
                          tf.expand_dims(pfx_time, 1)), fprop_dtype)
        pfx_mask = tf.einsum('BQ,QK->BQK', tf.cast(pfx_pad, fprop_dtype),
                             pfx_time_mask)
        pfx_segment_id = pfx_pad
        pfx_pos = pfx_time * pfx_pad
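        # Worked example (illustrative values): for a batch row with
        # pfx = [7 8 9 0 0 0] and pfx_len = 3:
        #   pfx_pad  = [1 1 0 0 0 0]  (tokens strictly before the last one)
        #   pfx_id   = [7 8 0 0 0 0]  (fed to the prefix fprop below)
        #   pfx_pos  = [0 1 0 0 0 0]
        #   pfx_last = 9              (becomes the first tgt_id of the loop)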

        if debug:
            tpu_summary.tensor('pfx_id', pfx_id)
            tpu_summary.tensor('pfx_len', pfx_len)
            tpu_summary.tensor('pfx_pos', pfx_pos)
            tpu_summary.tensor('pfx_last', pfx_last)

        # Now call decoder with prefix minus P*:
        # 'dec_state' now shall contain the key/value cache for prefix tokens
        # (for transformer models), and 'logits' we can either discard or
        # roll into the initial hyp_score. Discard is simpler.
        with tf.name_scope('prefix_fprop'):
            # TODO(krikun): remove extra type checks
            assert (pfx_id.dtype == tf.int32), (pfx_id.dtype)
            assert (pfx_segment_id.dtype == tf.int32), (pfx_segment_id.dtype)
            assert (pfx_pos.dtype == tf.int32), (pfx_pos.dtype)
            assert (pfx_mask.dtype == fprop_dtype), (pfx_mask.dtype)
            assert (t.dtype == tf.int32), (t.dtype)
            logits, dec_state = dec_callback(pfx_id, pfx_segment_id, pfx_pos,
                                             pfx_mask, dec_state, t)
            del logits

        # Now construct the initial state for the rest of the beam search loop.
        # 'tgt_id' is simply 'pfx_last' padded to [batch, beam] shape
        # 'tgt_pos' is different for each batch row and is equal to prefix_len
        # 'tgt_segment_id' always 1 (no packing)
        # 'hyp_score' is 0 for beam=0 and negative for beam>=1
        tgt_id = tf.zeros([batch_size, beam_size], tf.int32) + tf.expand_dims(
            pfx_last, 1)
        tgt_pos = tf.zeros([batch_size, beam_size], tf.int32) + tf.expand_dims(
            (pfx_len - 1), 1)
        hyp_score = tf.zeros(
            [batch_size, beam_size], dtype=fprop_dtype) - tf.cast(
                tf.range(beam_size, dtype=tf.float32) * 1e5, dtype=fprop_dtype)

        # TODO(krikun) Here we make the initial 't' constant and determined by
        # the shape of the prefix tensor 'pfx_max'. It is possible to make it
        # dynamic as t ~ max(pfx_len) / beam_size, which would allow more beam
        # search steps; however, 'max' results in a very slow all-to-all on
        # 16x16, and a variable number of decoder steps may result in bad
        # latency.
        t = tf.cast(tf.math.ceil(pfx_max / beam_size), tf.int32)

        # Initial tgt_mask is such that each token P* has attention on itself
        # (as usual) and on all prefix tokens before it, which are not padding.
        tgt_mask = tf.zeros([batch_size, beam_size, buf_size],
                            dtype=fprop_dtype)
        tgt_mask += tf.cast(
            tf.expand_dims(
                tf.pad(pfx_pad, [[0, 0], [0, (buf_size - pfx_max)]]), 1),
            fprop_dtype)
        tgt_mask += tf.one_hot(tf.range(beam_size) + t * beam_size,
                               buf_size,
                               dtype=fprop_dtype)

        if debug:
            tpu_summary.tensor('tgt_id', tgt_id)
            tpu_summary.tensor('tgt_pos', tgt_pos)
            tpu_summary.tensor('tgt_mask', tgt_mask)
            tpu_summary.tensor('t', t)

    with tf.name_scope('init_hist'):
        # h_tgt_id is used to recover topk_ids from nbest_mask
        h_tgt_id = tf.TensorArray(dtype=tf.int32, size=max_steps)
        h_tgt_pos = tf.TensorArray(dtype=tf.int32, size=max_steps)

        # When non-trivial prefix is present we also write prefix ids to
        # h_tgt_id so that the full sequence including prefix can be recovered
        # by unmask() below.  When prefix is empty, pfx_id shape is [batch, 0]
        # and the loop below becomes a no-op.
        # TODO(krikun): maybe a tf.while_loop is more appropriate here.
        for i, x_i in enumerate(tf.split(pfx_id, pfx_mul, 1)):
            h_tgt_id = h_tgt_id.write(i, x_i)
        for i, x_i in enumerate(tf.split(pfx_pos, pfx_mul, 1)):
            h_tgt_pos = h_tgt_pos.write(i, x_i)

        hist = (h_tgt_id, h_tgt_pos)
        tf.logging.info('hist=%r', hist)

    nbest_hyps = (nbest_mask, nbest_score, nbest_score_norm)
    tf.logging.info('nbest_hyps=%r', nbest_hyps)

    ext = (ext_id, ext_score, ext_mask)
    tf.logging.info('ext=%r', ext)

    loop_vars = (t, tgt_id, tgt_pos, tgt_mask, hyp_score, nbest_hyps, ext,
                 hist)
    tf.logging.info('loop_vars=%r', loop_vars)

    def loop_step(loop_vars, dec_state):  # pylint: disable=missing-docstring
        tf.logging.info('loop_vars=%r', loop_vars)
        tf.logging.info('dec_state=%r', dec_state)
        (t, tgt_id, tgt_pos, tgt_mask, hyp_score, nbest_hyps, ext,
         hist) = loop_vars
        (ext_id, ext_score, ext_mask) = ext
        (h_tgt_id, h_tgt_pos) = hist
        h_tgt_id = h_tgt_id.write(t, tgt_id, name='h_tgt_id')
        h_tgt_pos = h_tgt_pos.write(t, tgt_pos, name='h_tgt_pos')
        # not using tf.ones() here because of XLA compilation error
        tgt_segment_id = tgt_id * 0 + 1
        logits, dec_state = dec_callback(tgt_id, tgt_segment_id, tgt_pos,
                                         tgt_mask, dec_state, t)
        # take predicted EOS score for each hyp and compute normalized score
        eos_score = hyp_score + tf.cast(logits[:, :, eos_id], hyp_score.dtype)

        def length_norm(t):
            t = tf.cast(t, fprop_dtype)
            alpha = length_norm_alpha
            tf.logging.info('length_norm.alpha=%r', alpha)
            return tf.math.pow((t + 5.) / 5., alpha)
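        # For example, with the default length_norm_alpha = 0.8 a hypothesis
        # of length 10 is normalized by ((10 + 5) / 5) ** 0.8 ~= 2.41.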

        hyp_len = tgt_pos - tf.expand_dims((pfx_len - 1), -1)
        eos_score_norm = eos_score / length_norm(hyp_len)
        # update the n-best list
        nbest_hyps = update_nbest(nbest_hyps,
                                  (tgt_mask, hyp_score, eos_score_norm))

        if debug:
            tpu_summary.tensor('eos_score', eos_score)
            tpu_summary.tensor('hyp_len', hyp_len)

        # take top k tokens for each hyp
        k = beam_size
        with tf.name_scope('topk1'):
            top_score, top_id = top_k_fn(logits, k)
            top_score = tf.cast(top_score, fprop_dtype)

        top_score += tf.expand_dims(hyp_score, -1)
        top_score -= 1e9 * tf.cast(tf.equal(top_id, eos_id), fprop_dtype)

        top_score = tf.reshape(top_score, [batch_size, beam_size * k])
        top_id = tf.reshape(top_id, [batch_size, beam_size * k])
        top_mask = tf.repeat(tgt_mask, beam_size, 1)

        if debug:
            tpu_summary.tensor('top_id', top_id)
            tpu_summary.tensor('top_score', top_score)
            # tpu_summary.tensor('top_mask', top_mask)

        with tf.name_scope('update_ext'):
            # combine top k tokens with extension buffer (if any)
            if ext_size:
                ext_id = tf.concat([ext_id, top_id], 1)
                ext_score = tf.concat([ext_score, top_score], 1)
                ext_mask = tf.concat([ext_mask, top_mask], 1)
            else:
                ext_id, ext_score, ext_mask = top_id, top_score, top_mask

            # sort by score
            ext_score, i = tf.math.top_k(ext_score, ext_size + beam_size)
            i1 = tf.one_hot(i, ext_size + beam_size * k, dtype=fprop_dtype)
            ext_mask = tf.einsum('bkt,bjk->bjt', ext_mask, i1)
            ext_id = einsum_i32('bk,bjk->bj', ext_id, i1)

            # pick top beam_size extensions to evaluate at next iteration
            if ext_size:
                hyp_score = ext_score[:, :beam_size]
                ext_score = ext_score[:, beam_size:]
                tgt_id = ext_id[:, :beam_size]
                ext_id = ext_id[:, beam_size:]
                tgt_mask = ext_mask[:, :beam_size]
                ext_mask = ext_mask[:, beam_size:]
            else:
                hyp_score, tgt_id, tgt_mask = ext_score, ext_id, ext_mask
                ext_score = ext_id = ext_mask = 0

        tgt_pos = tf.reduce_sum(tgt_mask, -1)
        tgt_pos = tf.cast(tgt_pos, tf.int32)

        t += 1
        with tf.name_scope('tgt_mask_extend'):
            tgt_mask += tf.one_hot(tf.range(beam_size) + t * beam_size,
                                   buf_size,
                                   dtype=fprop_dtype)

        ext = (ext_id, ext_score, ext_mask)
        hist = (h_tgt_id, h_tgt_pos)
        loop_vars = (t, tgt_id, tgt_pos, tgt_mask, hyp_score, nbest_hyps, ext,
                     hist)
        tf.logging.info('loop_vars=%r', loop_vars)
        tf.logging.info('dec_state=%r', dec_state)
        return loop_vars, dec_state

    def loop_cond(loop_vars, dec_state):  # pylint: disable=missing-docstring
        tf.logging.info('loop_vars=%r', loop_vars)
        tf.logging.info('dec_state=%r', dec_state)
        if beam_gap is None:
            (t, _, _, _, _, _, _, _) = loop_vars
            return t < max_steps
        else:
            (t, _, _, _, hyp_score, nbest_hyps, _, _) = loop_vars
            (_, nbest_score, _) = nbest_hyps
            # stop early if all current hyps are significantly worse than nbest
            diff = tf.reduce_min(
                tf.reduce_min(nbest_score, -1) - tf.reduce_max(hyp_score, -1))
            return tf.math.logical_and(t < max_steps, diff < beam_gap)

    with tf.name_scope('flat_beam_search_loop'):
        (loop_vars, dec_state) = tf.while_loop(loop_cond,
                                               loop_step,
                                               loop_vars=(loop_vars,
                                                          dec_state),
                                               back_prop=False,
                                               swap_memory=False,
                                               maximum_iterations=max_steps)

    # flatten all tensorarrays into tensors
    (t, tgt_id, tgt_pos, tgt_mask, hyp_score, nbest_hyps, ext,
     hist) = loop_vars
    (nbest_mask, nbest_score, nbest_score_norm) = nbest_hyps
    (h_tgt_id, h_tgt_pos) = hist
    h_tgt_id = h_tgt_id.stack()
    h_tgt_pos = h_tgt_pos.stack()
    hist = (h_tgt_id, h_tgt_pos)
    loop_vars = (t, tgt_id, tgt_pos, tgt_mask, hyp_score, nbest_hyps, ext,
                 hist)

    # recover topk_ids from nbest_mask and tgt_id history
    h = tf.transpose(h_tgt_id, [1, 0, 2])
    h = tf.reshape(h, [batch_size, buf_size])

    def unmask(h, m):
        with tf.name_scope('unmask'):
            tpu_summary.tensor('unmask_h', h)
            tpu_summary.tensor('unmask_m', m)
            t = tf.cumsum(m, -1) * m - 1
            mh = einsum_i32('bkt,bt->bkt', m, h)
            t2 = tf.one_hot(tf.cast(t, tf.int32),
                            output_len,
                            dtype=fprop_dtype)
            x = einsum_i32('bkt,bktT->bkT', mh, t2)
            return tf.cast(x, h.dtype)
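    # unmask() gathers the tokens selected by each nbest mask into a dense
    # prefix of the output. Worked example (illustrative values): for a mask
    # row m = [1 0 1 1 0 ...] the scatter targets are
    # cumsum(m) * m - 1 = [0 -1 1 2 -1 ...], so the three selected tokens of h
    # land in output positions 0, 1, 2, while the -1 entries (one_hot of -1 is
    # all zeros) contribute nothing.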

    topk_ids = unmask(h, nbest_mask)
    topk_len = tf.reduce_sum(nbest_mask, -1)
    topk_len = tf.cast(topk_len, tf.int32)
    # add eos, because nbest_mask does not encode eos
    topk_ids += eos_id * tf.one_hot(topk_len, output_len, dtype=tf.int32)
    topk_len += 1
    topk_len = tf.minimum(topk_len, output_len)
    topk_score = nbest_score_norm

    nbest = (topk_ids, topk_len, topk_score)

    return loop_vars, dec_state, nbest
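

# A minimal usage sketch for flat_beam_search (illustrative only, not part of
# the beam search implementation). The dummy callback below stands in for a
# real decoder: it keeps dec_state unchanged and emits all-zero logits over a
# hypothetical vocabulary of 32 tokens; a real model would return next-token
# logits and an updated key/value cache.
def _flat_beam_search_usage_sketch():
  vocab_size = 32

  def dummy_dec_callback(tgt_id, tgt_segment_id, tgt_pos, tgt_mask, dec_state,
                         t):
    del tgt_segment_id, tgt_pos, tgt_mask, t  # unused by the dummy model
    logits = tf.zeros(tgt_id.shape.as_list() + [vocab_size], tf.float32)
    return logits, dec_state

  dec_state = tf.zeros([1], tf.float32)  # placeholder incremental state
  _, _, nbest = flat_beam_search(
      batch_size=2,
      beam_size=4,
      max_steps=8,
      dec_callback=dummy_dec_callback,
      dec_state=dec_state,
      bos_id=1,
      eos_id=2)
  topk_ids, topk_len, topk_score = nbest
  return topk_ids, topk_len, topk_score
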
def Top2Gating(w,
               inputs,
               paddings,
               num_devices,
               experts_dim,
               expert_capacity_dim,
               local_dispatch,
               fprop_dtype,
               use_xla_sharding=True,
               second_expert_policy='all',
               second_expert_threshold=0.0,
               legacy_mtf_behavior=True,
               capacity_factor=None):
  """Computes Top-2 gating for Mixture-of-Experts.

  See Top2GatingOnLogits for more details.

  Note that for local_dispatch the original batch BLM is reshaped into GSM, and
  each group `g = 0...G-1` is dispatched independently.

  Args:
    w: ME Tensor, gating weights for each expert.
    inputs: G`SM Tensor.
    paddings: G`S Tensor.
    num_devices: number of MoE devices for local dispatch
    experts_dim: number of experts.
    expert_capacity_dim: number of examples per minibatch (group) per expert.
      Each example is typically a vector of size input_dim, representing an
      embedded token or an element of a Transformer layer output.
    local_dispatch: whether dispatch is local to the group (G dim)
    fprop_dtype: activations datatype to use.
    use_xla_sharding: bool, True if this function is used for the xla_sharding
        case.
    second_expert_policy: 'all', 'sampling' or 'random'; for 'random' we
      optionally 'random'-ize dispatch to the second-best expert proportional
      to (weight / second_expert_threshold).
    second_expert_threshold: threshold for probability normalization for
      second_expert_policy == 'random'.
    legacy_mtf_behavior: True for legacy behavior with no re-normalization of
      expert assignment weights if we go over capacity or randomly decide to not
      dispatch to second expert.
    capacity_factor: if set, increases expert_capacity_dim to at least
      `(group_size * capacity_factor) / experts_dim`
      where `group_size` is the size of G dimension of `inputs`. If the
      value of expert_capacity_dim is already big enough no change is made.

  Returns:
    A tuple (dispatch_tensor, combine_tensor, aux_loss).

    - dispatch_tensor: G`SEC Tensor, scattering/dispatching inputs to
      experts.
    - combine_tensor: G`SEC Tensor for combining expert outputs.
    - aux_loss: auxiliary loss, equalizing the expert assignment ratios.
  """
  orig_inputs = inputs
  if not local_dispatch:
    inputs = tf.reshape(inputs, [1, inputs.shape[0] * inputs.shape[1], -1])
    if paddings is not None:
      paddings = tf.reshape(paddings, [1, -1])

  logits = tf.einsum('GSM,ME->GSE', inputs, w)

  top1_expert_per_example = tf.math.argmax(logits, -1)

  tpu_summary.tensor('top1_expert', top1_expert_per_example)

  aux_loss, combine_tensor, dispatch_tensor = Top2GatingOnLogits(
      inputs, paddings, logits, num_devices, experts_dim, expert_capacity_dim,
      fprop_dtype, use_xla_sharding, second_expert_policy,
      second_expert_threshold, legacy_mtf_behavior, capacity_factor)

  if not local_dispatch:
    dispatch_tensor = tf.reshape(
        dispatch_tensor, orig_inputs.shape[:2] + dispatch_tensor.shape[2:])
    combine_tensor = tf.reshape(
        combine_tensor, orig_inputs.shape[:2] + combine_tensor.shape[2:])

  return py_utils.NestedMap(
      combine_tensor=combine_tensor,
      dispatch_tensor=dispatch_tensor,
      aux_loss=aux_loss)
def Top2GatingOnLogits(inputs,
                       paddings,
                       logits,
                       num_devices,
                       experts_dim,
                       expert_capacity_dim,
                       fprop_dtype,
                       use_xla_sharding=True,
                       second_expert_policy='all',
                       second_expert_threshold=0.0,
                       legacy_mtf_behavior=True,
                       capacity_factor=None):
  """Computes Top-2 gating for Mixture-of-Experts.

  There are two expected usages of this function:

  1. used with xla_sharding. In this case, 'inputs' corresponds to a sharded
     tensor across multiple tpu cores. The operations within this function are
     automatically sharded/replicated across tpu cores.
  2. used within ML-Pathways. In this case, 'inputs' is always local to one tpu
     core. All computations below are carried out on one tpu core only. This
     function tries to dispatch examples across tpu cores in such a way that
     each expert is assigned no more than 'expert_capacity_dim' number of
     examples.

  Below, ` indicates the common way of splitting along the mesh dimension.

  Dimensions cheat sheet:

    G: group_dim
    S: group_size_dim
    E: number of experts
    C: capacity per expert
    M: model_dim (same as input_dim, same as output_dim)
    B: original batch_dim
    L: original sequence_length_dim

  Note that for local_dispatch the original batch BLM is reshaped into GSM, and
  each group `g = 0...G-1` is dispatched independently.

  Args:
    inputs: G`SM Tensor.
    paddings: G`S Tensor.
    logits: G`SE Tensor.
    num_devices: number of MoE devices for local dispatch
    experts_dim: number of experts.
    expert_capacity_dim: number of examples per minibatch (group) per expert.
      Each example is typically a vector of size input_dim, representing an
      embedded token or an element of a Transformer layer output.
    fprop_dtype: activations datatype to use.
    use_xla_sharding: bool, True if this function is used for the xla_sharding
      case.
    second_expert_policy: 'all', 'sampling' or 'random'.

      - 'all': we greedily pick the 2nd expert.
      - 'sampling': we sample the 2nd expert from the softmax.
      - 'random': we optionally 'random'-ize dispatch to second-best expert
        proportional to (weight / second_expert_threshold).

    second_expert_threshold: threshold for probability normalization for
      second_expert_policy == 'random'.
    legacy_mtf_behavior: bool, True if to match legacy mtf behavior exactly.
    capacity_factor: if set, increases expert_capacity_dim to at least
      (group_size * capacity_factor) / experts_dim
      where `group_size` is the size of G dimension of `inputs`. If the
      value of expert_capacity_dim is already big enough no change is made.

  TODO(lepikhin): get rid of the legacy_mtf_behavior flag.

  Returns:
    A tuple (aux_loss, combine_tensor, dispatch_tensor).

    - aux_loss: auxiliary loss, for equalizing the expert assignment ratios.
    - combine_tensor: G`SEC Tensor for combining expert outputs.
    - dispatch_tensor: G`SEC Tensor, scattering/dispatching inputs to
      experts.
  """
  del inputs  # inputs is currently not used.
  raw_gates = tf.nn.softmax(logits)  # along E dim

  if capacity_factor is not None:
    # Determine expert capacity automatically depending on the input size.
    group_size_dim = int(logits.shape[1])
    auto_expert_capacity = int((group_size_dim * capacity_factor) / experts_dim)
    if expert_capacity_dim < auto_expert_capacity:
      expert_capacity_dim = auto_expert_capacity
      # Round up to a multiple of 4 to avoid possible padding.
      while expert_capacity_dim % 4:
        expert_capacity_dim += 1
      tf.logging.info(
          'Setting expert_capacity_dim=%r (capacity_factor=%r '
          'group_size_dim=%r experts_dim=%r name_scope=%r)',
          expert_capacity_dim, capacity_factor, group_size_dim, experts_dim,
          tf.get_default_graph().get_name_scope())
    tpu_summary.scalar('expert_capacity', expert_capacity_dim)
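
  # Worked example (illustrative numbers): with group_size_dim = 1000,
  # capacity_factor = 2.0 and experts_dim = 64, auto_expert_capacity is
  # int(2000 / 64) = 31, which is rounded up to 32 (the next multiple of 4).
  # expert_capacity_dim is only ever raised by this logic, never lowered.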

  # top first and second gate value and expert index for each input
  #
  # GSK Tensors, K=2
  def _MaybeSplit(x):
    if use_xla_sharding:
      return Split(x, 0, num_devices)
    else:
      return x

  def _CreateOverCapacityRatioSummary(mask, position_in_expert, capacity, name):
    over_capacity = tf.reduce_sum(
        tf.cast(
            tf.greater_equal(mask * position_in_expert, capacity), mask.dtype))
    over_capacity_ratio = over_capacity / tf.reduce_sum(mask)
    py_utils.AddTpuSummaryTensor(name, over_capacity_ratio)
    tpu_summary.scalar(name, over_capacity_ratio, while_loop_reduce='mean')

  # As pointed out by zhifengc@ this method needs to be refactored. lepikhin@
  # and krikun@ will:
  #   - expand moe_spmd_test to compare Adafactor updates, slots on TPU
  #   including 2x2 with sharding
  #
  #   - add more tests for policy="random"
  #
  #   - add single step test for full size WMT model on CPU
  #
  # and then break this function into modules.
  #
  # GS
  index_1 = tf.math.argmax(raw_gates, axis=-1, output_type=tf.int32)
  index_1 = _MaybeSplit(index_1)
  tpu_summary.tensor('index_1', index_1)

  # GSE
  mask_1 = tf.one_hot(index_1, experts_dim, dtype=fprop_dtype)
  mask_1 = _MaybeSplit(mask_1)
  density_1_proxy = raw_gates

  importance = tf.ones_like(mask_1[:, :, 0])

  if paddings is not None:
    importance = 1.0 - paddings
    mask_1 *= tf.expand_dims(importance, -1)
    density_1_proxy *= tf.expand_dims(importance, -1)

  gate_1 = tf.einsum('GSE,GSE->GS', raw_gates, mask_1)
  gates_without_top_1 = raw_gates * (1.0 - mask_1)

  if second_expert_policy == 'sampling':
    # We directly sample the 2nd expert index from the softmax over the 2nd
    # expert by getting rid of the 1st expert already selected above. To do so,
    # we set a very negative value to the logit corresponding to the 1st expert.
    # Then we sample from the softmax (categorical) distribution using the
    # Gumbel max trick.
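    #
    # Reasoning note: adding independent Gumbel(0, 1) noise to logits and
    # taking the argmax yields a sample from the categorical distribution
    # softmax(logits) (the Gumbel-max trick). Because the 1st expert's logit
    # is first pushed to a very negative value, the argmax below samples the
    # 2nd expert from the renormalized distribution over the remaining
    # experts.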
    noise = _MaybeSplit(tf.random.uniform(logits.shape, dtype=logits.dtype))
    # Generates standard Gumbel(0, 1) noise, GSE Tensors
    noise = -tf.math.log(-tf.math.log(noise))
    very_negative_logits = _MaybeSplit(
        (tf.ones_like(logits) * logits.dtype.max *
         tf.constant(-0.7, dtype=logits.dtype)))
    # Gets rid of the first expert by setting its logit to be very negative
    updated_logits = _MaybeSplit(
        tf.where(mask_1 > 0.0, very_negative_logits, logits))
    # Adds the Gumbel noise to the updated logits
    noised_logits = _MaybeSplit(updated_logits + noise)
    # Picks the index of the largest noised logit as the 2nd expert. This is
    # equivalent to sampling from the softmax over the 2nd experts.
    index_2 = tf.math.argmax(noised_logits, axis=-1, output_type=tf.int32)
  else:
    index_2 = tf.math.argmax(gates_without_top_1, axis=-1, output_type=tf.int32)

  index_2 = _MaybeSplit(index_2)
  mask_2 = tf.one_hot(index_2, experts_dim, dtype=fprop_dtype)
  mask_2 = _MaybeSplit(mask_2)
  if paddings is not None:
    mask_2 *= tf.expand_dims(importance, -1)
  gate_2 = tf.einsum('GSE,GSE->GS', gates_without_top_1, mask_2)

  if legacy_mtf_behavior:
    # cl/298510175 moved this branch for gate_{1,2} denom calculation here.
    #
    # For policy=random, it's better to normalize gate_{1,2} before taking
    # capacity into account and before potentially dropping the second expert.
    #
    # According to mean_xent (http://short/_NzbZ5rINr5):
    #   MoE_512_102xen_PolicyAll_298510175
    #   MoE_512_102xen_PolicyRandom_298510175
    #
    # vs pre-cl/298510175
    #   MoE_512_102xen_PolicyRandom
    #   MoE_512_102xen_PolicyAll
    #
    # it substantially improves policy=random with threshold=0.5 which
    # historically was better than policy="all"
    #
    # Also confirmed this by decoding
    #   nmt_train/m4/data/es_en/test.txt
    #   nmt_train/m4/data/ru_en/test.txt
    #   nmt_train/m4/data/zh_en/test.txt
    # and improving BLEU
    #
    # moe_decode.MoE_512_102xen_PolicyRandom_298510175-160000.batch1024.beam4.c_dim4.ln0.8.rkv.mteval102
    #   0.421443
    #   0.327102
    #   0.315693
    # vs
    # moe_decode.feb18_non_fig_snapshot_2626_MoE_512_102xen_PolicyRandom-190000.batch1024.beam4.c_dim4.ln0.8.rkv.mteval102
    #   0.399232
    #   0.310606
    #   0.288229
    #
    # Additional comparison, see mean_xent http://short/_YHccOhQtdu with
    # legacy_mtf_behavior=False models
    #   3 - MoE_512_102xen_PolicyAll_LegacyFalse
    #   6 - MoE_512_102xen_PolicyRandom_LegacyFalse
    # shows that policy="random" gets worse with legacy_mtf_behavior=False, and
    # is similar to pre-cl/298510175
    #   4 - MoE_512_102xen_PolicyRandom
    #
    # gate_1 can become 0 due to Expert being out of capacity.
    #
    # gate_2 can become 0 due to
    #   second_expert_policy == 'random'
    # or "out of capacity" scenario.
    #
    # Here we renormalize regardless of cases above.
    denom = gate_1 + gate_2 + 1e-9
    gate_1 /= denom
    gate_2 /= denom

  # For each group g and expert index e in 0..E-1 independently, we compute
  # cumulative sums of the assignment indicators mask_1[g, :, e] along the S
  # axis, giving each token its position within expert e's buffer. The cumsum
  # is exclusive (exclusive=True below), so the first assigned token gets
  # position 0.
  position_in_expert_1 = tf.cumsum(mask_1, exclusive=True, axis=1)
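  # For example, for one group g and one expert column e, if
  #   mask_1[g, :, e]               = [1, 0, 1, 1, ...]
  # then
  #   position_in_expert_1[g, :, e] = [0, 1, 1, 2, ...]
  # i.e. the k-th token routed to expert e gets (0-based) slot k - 1; these
  # positions are later compared against `capacity` to drop overflow tokens.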

  # GS Tensor
  capacity = tf.cast(expert_capacity_dim, dtype=position_in_expert_1.dtype)

  # GE Tensor (reducing S out of GSE tensor mask_1)
  # density_1[:, e] represents assignment ratio (num assigned / total) to
  # expert e as top_1 expert without taking capacity into account.
  if legacy_mtf_behavior:
    density_denom = 1.0
  else:
    density_denom = tf.reduce_mean(
        importance, axis=(1))[:, tf.newaxis] + 1e-6
  density_1 = tf.reduce_mean(mask_1, axis=(1)) / density_denom
  # density_1_proxy[:, e] represents mean of raw_gates for expert e, including
  # those of examples not assigned to e with top_k.
  density_1_proxy = tf.reduce_mean(density_1_proxy, axis=1) / density_denom

  # The MoE paper (https://arxiv.org/pdf/1701.06538.pdf) uses an aux loss of
  # reduce_mean(density_1_proxy * density_1_proxy). Here we replace one of
  # the density_1_proxy with the discrete density_1 following
  # mesh_tensorflow/transformer/moe.py?rcl=283569345.
  aux_loss = tf.reduce_mean(density_1_proxy * density_1)  # element-wise
  aux_loss *= experts_dim * experts_dim  # const coefficient
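  # Scale note: under perfectly uniform routing both density_1 and
  # density_1_proxy are approximately 1 / experts_dim, so the mean product is
  # about 1 / experts_dim**2 and the experts_dim**2 coefficient keeps aux_loss
  # near 1 regardless of the number of experts.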

  # Add the over capacity ratio for expert 1
  _CreateOverCapacityRatioSummary(mask_1, position_in_expert_1, capacity,
                                  'over_capacity_1_ratio')

  mask_1 *= tf.cast(tf.less(position_in_expert_1, capacity), dtype=mask_1.dtype)
  position_in_expert_1 = tf.einsum('GSE,GSE->GS', position_in_expert_1, mask_1)

  # How many examples in this sequence go to this expert
  mask_1_count = tf.einsum('GSE->GE', mask_1)
  # GS Tensor - mostly ones, but zeros where something didn't fit
  mask_1_flat = tf.einsum('GSE->GS', mask_1)

  if second_expert_policy == 'all' or second_expert_policy == 'sampling':
    pass
  elif second_expert_policy == 'random':
    # gate_2 is between 0 and 1, reminder:
    #
    #   raw_gates = tf.nn.softmax(logits)
    #   index_1 = tf.math.argmax(raw_gates, axis=-1, output_type=tf.int32)
    #   mask_1 = tf.one_hot(index_1, experts_dim, dtype=fprop_dtype)
    #   gate_1 = tf.einsum('GSE,GSE->GS', raw_gates, mask_1)
    #
    # E.g. if gate_2 exceeds second_expert_threshold, then we definitely
    # dispatch to second-best expert. Otherwise we dispatch with probability
    # proportional to (gate_2 / threshold).
    #
    sampled_2 = tf.less(
        _MaybeSplit(tf.random.uniform(gate_2.shape, dtype=gate_2.dtype)),
        (gate_2 / max(second_expert_threshold, 1e-9)))
    gate_2 *= tf.cast(sampled_2, gate_2.dtype)
    mask_2 *= tf.cast(tf.expand_dims(sampled_2, -1), mask_2.dtype)
  else:
    raise ValueError(second_expert_policy)

  position_in_expert_2 = tf.cumsum(
      mask_2, exclusive=True, axis=1) + tf.expand_dims(mask_1_count, 1)

  # Add the over capacity ratio for expert 2
  _CreateOverCapacityRatioSummary(mask_2, position_in_expert_2, capacity,
                                  'over_capacity_2_ratio')

  mask_2 *= tf.cast(tf.less(position_in_expert_2, capacity), mask_2.dtype)
  position_in_expert_2 = tf.einsum('GSE,GSE->GS', position_in_expert_2, mask_2)
  mask_2_flat = tf.reduce_sum(mask_2, axis=-1)

  # Equivalent non-einsum implementation:
  #
  # position_in_expert_2 *= mask_2
  # position_in_expert_2 = tf.reduce_sum(
  #     position_in_expert_2, axis=-1, name='position_in_expert_2')

  gate_1 *= mask_1_flat
  gate_2 *= mask_2_flat

  if not legacy_mtf_behavior:
    denom = gate_1 + gate_2
    # To avoid divide by 0.
    denom = tf.where(denom > 0, denom, tf.ones_like(denom))
    gate_1 /= denom
    gate_2 /= denom

  # GSC Tensor
  b = tf.one_hot(
      tf.cast(position_in_expert_1, dtype=tf.int32),
      expert_capacity_dim,
      dtype=fprop_dtype,
      name='one_hot_b_0')
  # GSE Tensor
  a = tf.expand_dims(gate_1 * mask_1_flat, -1) * tf.one_hot(
      index_1, experts_dim, dtype=fprop_dtype)
  # GSEC Tensor
  first_part_of_combine_tensor = tf.einsum(
      'GSE,GSC->GSEC', a, b, name='first_part_of_combine_tensor')

  # GSC Tensor
  b = tf.one_hot(
      tf.cast(position_in_expert_2, dtype=tf.int32),
      expert_capacity_dim,
      dtype=fprop_dtype,
      name='one_hot_b_1')
  # GSE Tensor
  a = tf.expand_dims(gate_2 * mask_2_flat, -1) * tf.one_hot(
      index_2, experts_dim, dtype=fprop_dtype)
  second_part_of_combine_tensor = tf.einsum(
      'GSE,GSC->GSEC', a, b, name='second_part_of_combine_tensor')

  # GSEC Tensor
  combine_tensor = (
      first_part_of_combine_tensor + second_part_of_combine_tensor)
  combine_tensor = _MaybeSplit(combine_tensor)

  # GSEC Tensor
  dispatch_tensor = tf.cast(tf.cast(combine_tensor, tf.bool), fprop_dtype)
  dispatch_tensor = _MaybeSplit(dispatch_tensor)
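
  # Downstream (outside this function) dispatch_tensor is typically contracted
  # with the inputs to scatter tokens into per-expert buffers, and
  # combine_tensor is contracted with the expert outputs to merge them back,
  # e.g. roughly tf.einsum('GSEC,GSM->EGCM', dispatch_tensor, inputs) and
  # tf.einsum('GSEC,EGCM->GSM', combine_tensor, expert_outputs). The exact
  # einsum strings are illustrative assumptions, not taken from this module.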

  # TODO(yonghui): compute and return per-group aux_loss.
  return aux_loss, combine_tensor, dispatch_tensor