Beispiel #1
0
    def _MergeCandidates(tokens, candidates):
      """Applies the single best merge and refreshes the candidate list.

      The pair whose candidate value is smallest is replaced by that candidate.
      Afterwards only the candidates adjacent to the merge point are rebuilt;
      all others are copied over from the old candidate list.

      Args:
        tokens: 1-D tensor holding the current token sequence.
        candidates: 1-D tensor of merge candidates; entry i stands for merging
          tokens[i] with tokens[i + 1].

      Returns:
        A (tokens, candidates) tuple describing the state after one merge.
      """
      merge_pos = tf.argmin(candidates, output_type=tf.int32)
      # Collapse the two tokens at merge_pos into their merged token.
      tokens = tf.concat(
          [tokens[:merge_pos], [candidates[merge_pos]], tokens[merge_pos + 2:]],
          axis=0)
      no_candidates = tf.zeros([0], dtype=candidates.dtype)

      def _MergeLeft():
        # Candidates strictly left of the neighbor are still valid.
        untouched = candidates[:merge_pos - 1]
        refreshed = _MergeOneToken(tokens, merge_pos - 1)
        return tf.concat([untouched, refreshed], axis=0)

      def _MergeRight():
        # Candidates past the merged pair are still valid.
        refreshed = _MergeOneToken(tokens, merge_pos)
        untouched = candidates[merge_pos + 2:]
        return tf.concat([refreshed, untouched], axis=0)

      # A merge at the very front has no left neighbor to recompute.
      merged_at_front = tf.equal(merge_pos, 0)
      left_candidates = tf.cond(merged_at_front, lambda: no_candidates,
                                _MergeLeft)

      # A merge that produced the last token has no right neighbor.
      merged_at_back = tf.greater_equal(merge_pos, tf.size(tokens) - 1)
      right_candidates = tf.cond(merged_at_back, lambda: no_candidates,
                                 _MergeRight)

      return tokens, tf.concat([left_candidates, right_candidates], axis=0)
def ComputeSplits(batch_size, num_splits):
    """Returns how many values each of `num_splits` splits receives.

    Every split gets batch_size // num_splits values; the remainder (if any)
    is handed out one value at a time to the leading splits.

    Example::

      batch_size: 5
      num_splits: 3
      returns: [2, 2, 1]

    Args:
      batch_size: tensor of rank 0, size of tensor to be split
      num_splits: number of splits to split tensor into

    Returns:
      tensor of length num_splits containing sizes of each split
    """
    # Base share, repeated once per split.
    per_split = tf.div([batch_size], num_splits)
    base_sizes = tf.tile(per_split, tf.constant([num_splits], dtype=tf.int32))

    # One extra value for each of the first (batch_size % num_splits) splits.
    one = tf.constant([1])
    remainder = tf.math.floormod([batch_size], num_splits)
    extras = tf.tile(one, remainder)

    # Right-pad the extras with zeros so they line up with base_sizes.
    zero = tf.constant([0])
    num_missing = tf.subtract(tf.shape(base_sizes), tf.shape(extras))
    extras = tf.concat([extras, tf.tile(zero, num_missing)], 0)

    sizes = tf.add(base_sizes, extras)
    # for some reason TF erases shape information if num_splits is 1
    if num_splits == 1:
        sizes.set_shape([1])
    return sizes
Beispiel #3
0
    def FProp(self, theta, prepared_inputs, step_inputs, padding, state0):
        """Runs every sub-step on the same inputs and concatenates the results.

        Args:
          theta: A `.NestedMap` object containing weights' values of this layer
            and its children layers; `theta.sub[i]` is passed to sub-step i.
          prepared_inputs: An output from PrepareExternalInputs;
            `prepared_inputs.sub[i]` is passed to sub-step i.
          step_inputs: Step inputs, forwarded unchanged to every sub-step.
          padding: Padding, forwarded unchanged to every sub-step.
          state0: The previous recurrent state; `state0.sub[i]` is passed to
            sub-step i.

        Returns:
          A tuple (output, state1):

          - output: A `.NestedMap` whose `output` field is the concatenation
            along axis 1 of the first value returned by each sub-step.
          - state1: A `.NestedMap` whose `sub` list holds the new state of each
            sub-step, in order.
        """
        outputs = []
        sub_states = []
        for i, sub_step in enumerate(self.sub):
            sub_output, sub_state = sub_step.FProp(theta.sub[i],
                                                   prepared_inputs.sub[i],
                                                   step_inputs, padding,
                                                   state0.sub[i])
            outputs.append(sub_output)
            sub_states.append(sub_state)

        state1 = py_utils.NestedMap(sub=sub_states)
        output = py_utils.NestedMap(output=tf.concat(outputs, axis=1))
        return output, state1
Beispiel #4
0
def AddMultiCurveSubplot(fig,
                         tensors,
                         paddings,
                         labels,
                         xlabels=None,
                         **kwargs):
    """Adds a multi curve subplot to Matplotlib figure.

    Draws one curve per non-None entry of `tensors`, labelled with the
    matching entry of `labels`.

    Args:
      fig: The Matplotlib figure.
      tensors: List of tensors of shape [batch, length]; None entries are
        skipped together with their label.
      paddings: Paddings for 'tensors' with shape [batch, length] with 0. in
        valid positions and 1. in invalid.
      labels: A list of tensor names (strings) of the same length as 'tensors'.
      xlabels: A string tensor of shape [batch] with an xlabel per batch.
      **kwargs: With optional, title, xlabel, ylabel, fontsize.
    """
    present = [(tensor, name)
               for tensor, name in zip(tensors, labels)
               if tensor is not None]
    data = [py_utils.ApplyPadding(paddings, tensor) for tensor, _ in present]
    row_labels = [name for _, name in present]

    shape = py_utils.GetShape(data[0], 2)
    num_rows = len(data)
    # Concatenating on the last axis and reshaping yields [batch, row, length].
    stacked = tf.reshape(tf.concat(data, -1), [shape[0], num_rows, shape[1]])

    plot_args = [stacked, py_utils.LengthsFromPaddings(paddings)]
    if xlabels is not None:
        plot_args.append(xlabels)
    fig.AddSubplot(plot_args,
                   plot_func=_AddMultiCurveRowPlots,
                   row_labels=row_labels,
                   **kwargs)
  def _OutfeedDequeue(self):
    """Collects decode outfeeds from all devices and concatenates them.

    For every (replica, core) pair an outfeed tuple matching the structure of
    `self.metrics_nm` is dequeued on that core's host device, repacked into
    that structure, and passed through the decode task's
    `PostProcessDecodeHost`. The selected keys of each result are then
    concatenated along axis 0 across devices.

    Returns:
      A dict mapping each of the hard-coded keys to the axis-0 concatenation
      of the per-device values.
    """
    # The outfeed signature is identical for every device, so derive it once
    # instead of re-flattening the metrics for each core.
    flat_metrics = self.metrics_nm.Flatten()
    dtypes = [x.dtype for x in flat_metrics]
    shapes = [x.shape for x in flat_metrics]

    # Hard-coding for Transformer/MLPerf.
    keys = ['target_ids', 'eval_weight', 'tlen', 'top_ids', 'top_lens']
    concat_lists = {key: [] for key in keys}

    device_assignment = py_utils.GetTpuDeviceAssignment()
    assert device_assignment
    # With SPMD only one core per replica is dequeued.
    num_cores_per_replica = 1 if self.spmd else (
        device_assignment.num_cores_per_replica)

    for replica in range(device_assignment.num_replicas):
      for core in range(num_cores_per_replica):
        with tf.device(device_assignment.host_device(replica, core)):
          outfeeds_per_core = tpu_ops.outfeed_dequeue_tuple(
              dtypes=dtypes,
              shapes=shapes,
              device_ordinal=device_assignment.tpu_ordinal(replica, core))
          packed = tf.nest.pack_sequence_as(self.metrics_nm, outfeeds_per_core)
          outfeed_dict = self._decode_model_task.PostProcessDecodeHost(packed)
          for key in keys:
            concat_lists[key].append(outfeed_dict[key])

    return {key: tf.concat(concat_lists[key], 0) for key in keys}
Beispiel #6
0
        def _DerivePaddingsAndIds(src_ids, tgt_labels):
            """Derives target ids, paddings, weights and a bucket key.

            tgt_ids is tgt_labels shifted right by one, with a SOS ID
            prepended. All paddings are zero and all weights are one, so the
            bucket key is the larger of the source and target lengths.

            Args:
              src_ids: Source ids; only their shape is used.
              tgt_labels: 1-D target labels.

            Returns:
              A tuple (src_paddings, tgt_ids, tgt_paddings, tgt_weights,
              bucket_key).
            """
            tgt_ids = tf.concat([[p.sos_id], tgt_labels[:-1]], axis=0)

            # Nothing is padded at this point.
            src_paddings = tf.zeros(tf.shape(src_ids), dtype=tf.float32)
            tgt_paddings = tf.zeros(tf.shape(tgt_ids), dtype=tf.float32)
            tgt_weights = tf.ones(tf.shape(tgt_ids), dtype=tf.float32)

            src_len = tf.reduce_sum(1.0 - src_paddings)
            tgt_len = tf.reduce_sum(1.0 - tgt_paddings)
            bucket_key = tf.cast(tf.maximum(src_len, tgt_len), tf.int32)

            return src_paddings, tgt_ids, tgt_paddings, tgt_weights, bucket_key
Beispiel #7
0
def ExtractBlockContext(x,
                        block_size,
                        left_context,
                        right_context,
                        padding_val=0.0):
  """Gathers each block of frames together with its temporal context.

  Args:
    x: a tensor of [batch, time, ...].
    block_size: int. Number of time frames in a block.
    left_context: int. Left context size; 1 means no extra frames on the left.
    right_context: int. Right context size.
    padding_val: float. value on the padded frames.

  Returns:
    A tensor of [batch, num_blocks, context_size, ...], with necessary paddings,
    where context_size = block_size + (left_context - 1) + right_context,
    and output[:, i, ...] are x[:, start-left_context+1:end+right_context, ...],
    start = i * block_size, end = (i + 1) * block_size.

  Raises:
    ValueError: If block_size < 1, left_context is outside
      [1, block_size + 1], or right_context is outside [0, block_size].
  """
  if block_size < 1:
    raise ValueError('block_size must be at least 1, got {}'.format(block_size))
  if left_context < 1 or left_context > block_size + 1:
    raise ValueError(
        'left_context must be at least 1 and at most block_size + 1 = {}, '
        'got {}'.format(block_size + 1, left_context))
  if right_context < 0 or right_context > block_size:
    raise ValueError(
        'right_context must be at least 0 and at most block_size = {}, '
        'got {}'.format(block_size, right_context))

  center = ConvertToBlocks(x, block_size, padding_val)
  # Number of frames actually borrowed from the left.
  num_left = left_context - 1
  pieces = []

  if num_left > 0:
    if num_left == block_size:
      # The left context is exactly the previous block.
      left = tf.roll(center, shift=1, axis=1)
    else:
      # Shift time right so each block starts with its preceding frames.
      shifted = ConvertToBlocks(
          tf.roll(x, shift=num_left, axis=1), block_size, padding_val)
      left = shifted[:, :, :num_left, ...]
    pieces.append(left)

  pieces.append(center)

  if right_context > 0:
    if right_context == block_size:
      # The right context is exactly the next block.
      right = tf.roll(center, shift=-1, axis=1)
    else:
      # Shift time left so each block ends with its following frames.
      shifted = ConvertToBlocks(
          tf.roll(x, shift=-right_context, axis=1), block_size, padding_val)
      right = shifted[:, :, -right_context:, ...]
    pieces.append(right)

  return tf.concat(pieces, axis=2)
Beispiel #8
0
  def FProp(self, theta, inputs, paddings):
    """Applies the convolution to padded sequence inputs.

    Args:
      theta: A `.NestedMap` object containing weights' values of this layer and
        its children layers.
      inputs: The inputs tensor. It is expected to be of shape [batch, time,
        frequency, channel]. The time dimension corresponds to the height
        dimension as in images and the frequency dimension corresponds to the
        width dimension as in images.
      paddings: The paddings tensor, expected to be of shape [batch, time].

    Returns:
      outputs, out_paddings pair.
    """
    p = self.params

    def _ApplyPadding(tensor_in, padding_in):
      # Append two trailing unit dims to the paddings so the keep-mask
      # broadcasts against the 4-D tensor.
      keep_mask = 1.0 - tf.expand_dims(tf.expand_dims(padding_in, -1), -1)
      return tensor_in * keep_mask

    with tf.name_scope(p.name):
      # paddings must be rank 2; inputs must share its leading two dims and
      # end with the expected channel count.
      paddings_check = py_utils.assert_shape_match(
          tf.shape(paddings), [-1, -1])
      inputs_shape = tf.shape(inputs)
      expected_shape = tf.concat(
          [tf.shape(paddings), [-1, symbolic.ToStatic(self.input_channels)]],
          0)
      inputs_check = py_utils.assert_shape_match(inputs_shape, expected_shape)
      inputs = py_utils.with_dependencies([paddings_check, inputs_check],
                                          inputs)

      # Zeroing out padded inputs.
      inputs = _ApplyPadding(inputs, paddings)

      # Apply conv on 'inputs'.
      out = self._ApplyConv(theta, inputs)
      if p.partial_conv:
        out = self._RescaleBoundary(out, paddings)

      # NOTE: this may be slightly inaccurate when p.dilation_rate[0] > 1.
      # But there's likely no real problems. Trying to set it gives an error:
      # pooling with SAME padding is not implemented for dilation_rate > 1.
      # NOTE: we use window=p.filter_stride[0] to be compatible with legacy
      # implementation.  Consider updating it to be the actual shape.
      time_stride = p.filter_stride[0]
      conv_padding = ComputeConvOutputPadding(
          paddings, window=time_stride, stride=time_stride)

      # The output is deliberately not re-masked with conv_padding here;
      # subsequent layers are assumed to zero out padded nodes if necessary.
      expected_out_shape = symbolic.ToStatic(self.OutShape(tf.shape(inputs)))
      out = py_utils.HasShape(out, expected_out_shape)
      return out, conv_padding
Beispiel #9
0
    def _TokenizeOneSentence(i, strs, token_ids_ta, target_ids_ta, paddings_ta):
      """Tokenizes sentence `i` and writes its rows into the TensorArrays.

      Args:
        i: Index of the sentence to tokenize.
        strs: The sentences; `strs[i]` is encoded.
        token_ids_ta: TensorArray receiving the sos-prefixed ids.
        target_ids_ta: TensorArray receiving the ids without the sos prefix.
        paddings_ta: TensorArray receiving the paddings (0 on tokens, 1 on
          padded positions).

      Returns:
        The loop state for the next sentence: (i + 1, strs, token_ids_ta,
        target_ids_ta, paddings_ta).
      """
      ids, _ = self._wpm_encoder.Encode(strs[i])
      if append_eos:
        ids = tf.concat([ids, [self.eos_id]], axis=0)

      # Every row is padded or trimmed to max_length. Trimming happens after
      # the eos is added, so some sentences might not have </s> at the end.
      sos_prefixed_ids = tf.concat([[self.sos_id], ids], axis=0)
      token_row = py_utils.PadOrTrimTo(sos_prefixed_ids, [max_length],
                                       self.eos_id)
      token_ids_ta = token_ids_ta.write(i, token_row)

      target_row = py_utils.PadOrTrimTo(ids, [max_length], self.eos_id)
      target_ids_ta = target_ids_ta.write(i, target_row)

      # Real tokens get padding 0; positions added by padding get 1.
      token_paddings = tf.zeros_like(ids, dtype=tf.float32)
      padding_row = py_utils.PadOrTrimTo(token_paddings, [max_length], 1.)
      paddings_ta = paddings_ta.write(i, padding_row)

      return i + 1, strs, token_ids_ta, target_ids_ta, paddings_ta
Beispiel #10
0
    def FProp(self, theta, prepared_inputs, step_inputs, padding, state0):
        """Performs inference on the stack of sub-steps.

        There are three possible ways to feed input to the stack:

          * step_inputs.inputs: These tensors are fed only to the lowest layer.
          * step_inputs.context: [Optional] This tensor is fed to every layer.
          * prepared_inputs: [Optional] This tensor is fed to every layer and
              is assumed to stay constant over all steps.

        Each layer above the lowest receives the (possibly residual-augmented)
        output of the layer below it as its input.

        Args:
          theta: A `.NestedMap` object containing weights' values of this layer
            and its children layers.
          prepared_inputs: An output from PrepareExternalInputs.
          step_inputs: A `.NestedMap` containing a list called 'inputs', and
            optionally a tensor called 'context'.
          padding: A 0/1 float tensor of shape [batch_size]; 1.0 means that
            this batch element is empty in this step.
          state0: The previous recurrent state.

        Returns:
          A tuple (output, state1):

          - output: A `.NestedMap` containing the output of the top-most step.
          - state1: The recurrent state to feed to next invocation of this
            graph.
        """
        params = self.params
        state1 = py_utils.NestedMap(sub=[])
        layer_inputs = list(step_inputs.inputs)
        # We pretend that the input is the output of layer -1 for the purposes
        # of residual connections, so residual_inputs[k] is layer k-1's output.
        residual_inputs = [tf.concat(layer_inputs, axis=1)]
        shared_inputs = ([step_inputs.context]
                         if 'context' in step_inputs else [])

        for i, sub_step in enumerate(self.sub):
            sub_inputs = py_utils.NestedMap(inputs=layer_inputs + shared_inputs)
            sub_output, sub_state = sub_step.FProp(theta.sub[i],
                                                   prepared_inputs.sub[i],
                                                   sub_inputs, padding,
                                                   state0.sub[i])
            state1.sub.append(sub_state)
            output = sub_output.output
            if 0 <= params.residual_start <= i:
                # residual_inputs contains the step input at residual_inputs[0].
                residual_index = i + 1 - params.residual_stride
                assert residual_index < len(residual_inputs)
                output += residual_inputs[residual_index]
            residual_inputs.append(output)
            layer_inputs = [output]

        return py_utils.NestedMap(output=output), state1
def SequenceConcat(x, x_paddings, y, y_paddings, pad=0):
    """Concats sequence `x` with sequence `y`.

    This function is length aware (based off the paddings): for every batch
    row, the tokens of `y` are written immediately after the valid tokens of
    `x`, and everything past the combined length is filled with `pad`.

    Args:
      x: A sequence of tokens of shape [batch_size, x_len_max].
      x_paddings: The paddings of `x`.
      y: A sequence of tokens of shape [batch_size, y_len_max].
      y_paddings: The paddings of `y`.
      pad: The <pad> token to fill the concatenated sequence (of type integer).

    Returns:
      A tuple.
        - Concatenation of `x` and `y` of shape
          [batch_size, x_len_max + y_len_max].
        - Paddings of the concatenation of shape
          [batch_size, x_len_max + y_len_max].
    """
    # Per-row lengths (w/ eos), recovered from the paddings.
    x_len = tf.cast(tf.round(tf.reduce_sum(1 - x_paddings, 1)), tf.int32)
    y_len = tf.cast(tf.round(tf.reduce_sum(1 - y_paddings, 1)), tf.int32)
    xy_len = x_len + y_len

    batch_size = py_utils.GetShape(x)[0]
    y_len_max = py_utils.GetShape(y)[1]

    # Widen `x` to the output width, then zero every <pad> entry so that `y`
    # can simply be added on top.
    x = tf.concat([x, tf.fill(py_utils.GetShape(y), pad)], 1)
    x = tf.where(tf.not_equal(x, pad), x, tf.fill(py_utils.GetShape(x), 0))

    # Entry (b, j) of `y` is written to row b, column x_len[b] + j.
    row_ids = tf.tile(tf.expand_dims(tf.range(batch_size), 1), [1, y_len_max])
    col_ids = (
        tf.tile(tf.expand_dims(tf.range(y_len_max), 0), [batch_size, 1]) +
        tf.expand_dims(x_len, 1))
    indices = tf.stack([row_ids, col_ids], 2)

    xy = x + tf.scatter_nd(indices, y, py_utils.GetShape(x))

    # Everything beyond the combined length is remapped to `pad`.
    positions = tf.expand_dims(tf.range(py_utils.GetShape(xy)[1]), 0)
    is_valid = tf.less(positions, tf.expand_dims(xy_len, 1))
    xy = tf.where(is_valid, xy, tf.fill(py_utils.GetShape(xy), pad))

    xy_paddings = 1 - tf.sequence_mask(xy_len,
                                       py_utils.GetShape(xy)[1],
                                       x_paddings.dtype)
    return xy, xy_paddings
 def _OutfeedDequeue(self):
   """Collects outfeed dequeues from all devices.

   For every (replica, core) pair an outfeed tuple matching the flattened
   `self.metrics_nm` is dequeued on that core's host device. The values for
   each outfeed position are then concatenated along axis 0 across devices.

   Returns:
     A list with one tensor per flattened entry of `self.metrics_nm`, each the
     axis-0 concatenation of that entry's per-device outfeeds.
   """
   # The outfeed signature is identical for every device, so derive it once
   # instead of re-flattening the metrics for each core.
   flat_metrics = self.metrics_nm.Flatten()
   dtypes = [x.dtype for x in flat_metrics]
   shapes = [x.shape for x in flat_metrics]
   # One independent list per outfeed; `[[]] * n` would alias a single list.
   outfeed_ops = [[] for _ in flat_metrics]

   device_assignment = py_utils.GetTpuDeviceAssignment()
   assert device_assignment
   # With SPMD only one core per replica is dequeued.
   num_cores_per_replica = 1 if self.spmd else (
       device_assignment.num_cores_per_replica)

   for replica in range(device_assignment.num_replicas):
     for core in range(num_cores_per_replica):
       with tf.device(device_assignment.host_device(replica, core)):
         outfeeds_per_core = tpu_ops.outfeed_dequeue_tuple(
             dtypes=dtypes,
             shapes=shapes,
             device_ordinal=device_assignment.tpu_ordinal(replica, core))
         for idx_outfeed, out_feed in enumerate(outfeeds_per_core):
           outfeed_ops[idx_outfeed].append(out_feed)
   return [tf.concat(per_outfeed, 0) for per_outfeed in outfeed_ops]
Beispiel #13
0
 def _Slice(tensor):
     """Returns `tensor` at time=state0.t with the time axis removed."""
     time_axis = self.params.axis
     shape = py_utils.GetShape(tensor)

     # Start offsets: zero everywhere except state0.t on the time axis,
     # e.g. [0, t, 0, 0, ...] when the time axis is 1.
     begin = tf.one_hot(time_axis, tf.rank(tensor), on_value=state0.t)

     # Extents: the full shape, but a single step along the time axis,
     # e.g. [shape[0], 1, shape[2], ...] when the time axis is 1.
     dims_before = shape[0:time_axis]
     single_step = tf.constant([1], dtype=tf.int32)
     dims_after = shape[time_axis + 1:]
     size = tf.concat([dims_before, single_step, dims_after], axis=0)

     # Cut out the single time step, then drop the now unit-sized axis.
     time_slice = tf.slice(tensor, begin, size)
     return tf.squeeze(time_slice, axis=time_axis)
Beispiel #14
0
    def ProjectInputSequence(self, theta, inputs):
        """Applies input projection for the entire sequence.

        Args:
          theta: a NestedMap of layer weights. Notably, it's expected to
            contain separate weight tensors for input and hidden state
            projections, for performance reasons, under the key 'wm_i' (input)
            and 'wm_h'. Only 'wm_i' is read here.
          inputs: A NestedMap with the following fields:
            - act: A list of Tensors of shape [seqlen, batch, input_dim]. With
              more than one entry they are concatenated on the last axis.

        Returns:
          A Tensor of shape [seqlen, batch, 4 * hidden_dim].
        """
        acts = inputs.act
        assert isinstance(acts, list)
        # A single activation needs no concat.
        x = tf.concat(acts, -1) if len(acts) > 1 else acts[0]
        # [T, B, D] x [D, H'] -> [T, B, H'] with H' = 4 * H.
        return tf.einsum('TBD,DH->TBH', x, theta.wm_i)
  def FProp(self, theta, prepared_inputs, step_inputs, padding, state0):
    """Produces a context vector from the attention algorithm.

    The context vector is a summary of the inputs from external_inputs
    which the attention algorithm has determined would be useful for decoding
    the next output.

    Args:
      theta: A NestedMap containing weights' values of this layer and its
        children layers.
      prepared_inputs: A set of encoded tensors that have been pre-processed by
        PrepareExternalInputs.
      step_inputs: A NestedMap containing an 'inputs' list; its tensors are
        concatenated along axis 1 to form the query vector.
      padding: A [batch, 1] 0/1 float tensor, where 1.0 means that this batch
        slot is not used.
      state0: A NestedMap of state, either produced by ZeroState or a previous
        invocation of this graph.

    Returns:
      output, state1, defined as follows:
      - output: a NestedMap with `context`, the attention context, and
        `probs`, the attention probabilities with `padding` applied.
      - state1: a NestedMap with `atten_context` and `atten_state`, to be used
        in subsequent invocations of this graph.
    """
    query = tf.concat(step_inputs.inputs, axis=1)
    atten_context, atten_probs, atten_state = (
        self.atten.ComputeContextVectorWithSource(
            theta.atten,
            prepared_inputs.packed_src,
            query,
            attention_state=state0.atten_state))
    atten_probs = py_utils.ApplyPadding(padding, atten_probs)

    output = py_utils.NestedMap(context=atten_context, probs=atten_probs)
    state1 = py_utils.NestedMap(
        atten_context=atten_context, atten_state=atten_state)
    return output, state1
Beispiel #16
0
 def _MergeRight():
   """Rebuilds the candidates from the merge position to the end."""
   # Only the candidate at best_id changes; those past the merged pair are
   # reused as they are.
   refreshed = _MergeOneToken(tokens, best_id)
   untouched = candidates[best_id + 2:]
   return tf.concat([refreshed, untouched], axis=0)
Beispiel #17
0
 def _MergeLeft():
   """Rebuilds the candidates that precede the merge position."""
   # Candidates before the left neighbor are reused; only the neighbor's
   # candidate at best_id - 1 is recomputed.
   untouched = candidates[:best_id - 1]
   refreshed = _MergeOneToken(tokens, best_id - 1)
   return tf.concat([untouched, refreshed], axis=0)
    def FProp(self,
              theta,
              x,
              x_paddings=None,
              eos_id=1,
              force_sample_last_token=True):
        """Applies SymbolInsertionLayer.

    We take in a `x`, which represents the groundtruth sequence (i.e., English
    sequence). We return a sampled rollin (observed) canvas (i.e., random subset
    of the English sequence), as well as the target (indices) for an
    insertion-based model (i.e., the targets given the random observed subset).

    Args:
      theta: Ignored, this can be None.
      x: The symbol ids of shape `[batch_size, time_dim]`.
      x_paddings: The paddings (1 or 0) of shape `[batch_size, time_dim]` where
        0 is valid and 1 is invalid. Defaults to all zeros (everything valid).
      eos_id: The <eos> token id to represent end-of-slot.
      force_sample_last_token: Set True to force sample the last token of `x`.

    Returns:
      A `NestedMap`.
        - canvas: The canvas (based off of the `rollin_policy`) of shape
          [batch_size, c_dim]. Note that, `c_dim` <= `time_dim` but need not be
          equal.
        - canvas_indices: The canvas indices (into `x`).
        - canvas_paddings: The paddings of `canvas_indices`.
        - target_indices: The target indices of shape [num_targets, 3].
          `num_targets` is the number of total targets in the entire batch.
          [:, 0] captures the batch, [:, 1] captures the slot, and [:, 2]
          captures the token. Each row [batch, slot, vocab] represents the
          indices of the target -- i.e., the batch, slot and vocab combination
          of the target. Typical usage of these indices is to tf.gather_nd
          the log-probs (from the softmax layer).
        - target_weights: The target weights.

    Raises:
      ValueError: If the rollin or oracle policy is anything but 'uniform'.
    """
        p = self.params

        batch_size = py_utils.GetShape(x)[0]
        time_dim = py_utils.GetShape(x)[1]

        if x_paddings is None:
            x_paddings = tf.zeros([batch_size, time_dim], tf.float32)

        # A rollin policy of 'oracle' means "use whatever the oracle policy is".
        oracle_policy = p.oracle_policy
        rollin_policy = (oracle_policy
                         if p.rollin_policy == 'oracle' else p.rollin_policy)

        if rollin_policy != 'uniform':
            raise ValueError('Unknown or unsupported rollin policy: %s' %
                             rollin_policy)
        if oracle_policy != 'uniform':
            raise ValueError('Unknown or unsupported oracle policy: %s' %
                             oracle_policy)

        # Number of valid tokens per example, recovered from the paddings.
        x_len = tf.cast(tf.round(tf.reduce_sum(1 - x_paddings, 1)), tf.int32)

        # Compute the desired length per example in the batch.
        # NOTE(review): this draw and the Gumbel noise below are both seeded
        # with p.random_seed -- confirm that sharing the seed is intended.
        ratio = tf.random.uniform([batch_size], 0.0, 1.0, seed=p.random_seed)
        if force_sample_last_token:
            # floor(ratio * x_len) is capped at x_len - 1, then one slot is
            # added for the forced last token.
            c_len = tf.minimum(
                tf.cast(ratio * tf.cast(x_len, tf.float32), tf.int32),
                x_len - 1) + 1
        else:
            # floor(ratio * (x_len + 1)) capped at x_len; may be zero.
            c_len = tf.minimum(
                tf.cast(ratio * tf.cast(x_len + 1, tf.float32), tf.int32),
                x_len)
        # Compute the maximum length across the batch.
        c_len_max = tf.reduce_max(c_len)

        # Grab subset of random valid indices per example.
        # Positions at or beyond x_len get a -1e9 logit so they rank last.
        z_logits = tf.cast(
            tf.expand_dims(tf.range(time_dim), 0) >= tf.expand_dims(x_len, 1),
            tf.float32) * -1e9
        if force_sample_last_token:
            # Force sample the last token -- i.e., as indexed by `x_len - 1`. We can
            # accomplish this by add +LARGE_NUMBER to the logits.
            z_logits += tf.cast(
                tf.equal(tf.expand_dims(tf.range(time_dim), 0),
                         tf.expand_dims(x_len - 1, 1)), tf.float32) * 1e9
        # Gumbel-max trick to sample (we only sample valid positions per sample in
        # the batch).
        z = -tf.math.log(-tf.math.log(
            tf.random.uniform([batch_size, time_dim], seed=p.random_seed)))
        # top_k with k=time_dim orders every position by its perturbed logit.
        unused_c_values, c_indices = tf.nn.top_k(z_logits + z, time_dim)

        # Trim everything > c_len_max.
        c_indices = c_indices[:, :c_len_max]

        # Invalidate any indices >= c_len, we use the last index as the default
        # invalid index.
        c_indices = tf.where(
            tf.expand_dims(tf.range(c_len_max), 0) < tf.expand_dims(c_len, 1),
            c_indices, tf.fill(py_utils.GetShape(c_indices), time_dim - 1))

        # Materialize the canvas.
        # Sorting puts the chosen positions back into sequence order.
        c_indices = tf.sort(c_indices)
        # Gather x[b, c_indices[b, j]] using flattened (batch, position) pairs.
        c = tf.gather_nd(
            x,
            tf.stack([
                tf.reshape(
                    tf.tile(tf.expand_dims(tf.range(batch_size), 1),
                            [1, c_len_max]), [-1]),
                tf.reshape(c_indices, [-1])
            ], 1))
        c = tf.reshape(c, [batch_size, c_len_max])

        # Compute the paddings.
        c_paddings = 1 - tf.sequence_mask(
            c_len, c_len_max, dtype=x_paddings.dtype)
        # Zero the canvas entries that lie in padded positions.
        c *= tf.cast(1 - c_paddings, tf.int32)

        # (batch, position) pairs of every canvas index, flattened to
        # [batch_size * c_len_max, 2].
        indices = tf.concat([
            tf.reshape(
                tf.tile(tf.expand_dims(tf.range(batch_size), 1),
                        [1, c_len_max]), [batch_size * c_len_max, 1]),
            tf.reshape(c_indices, [batch_size * c_len_max, 1])
        ], 1)
        # Non-zero at every position of `x` that appears in the canvas indices.
        x_token_is_observed = tf.scatter_nd(
            indices, tf.ones([batch_size * c_len_max], tf.int32),
            py_utils.GetShape(x))
        # `x_segments` captures which slot each `x` belongs to (both observed and
        # tokens that need to be observed).
        x_segments = tf.cumsum(x_token_is_observed, 1, exclusive=True)

        x_token_is_observed = tf.cast(x_token_is_observed, tf.bool)
        # Observed-flag of the preceding position; the first position is
        # treated as if its predecessor were observed.
        prev_x_token_is_observed = tf.pad(x_token_is_observed[:, :-1],
                                          [[0, 0], [1, 0]],
                                          constant_values=True)
        # Flatten to one entry per (batch, time) position.
        x_token_is_observed = tf.reshape(x_token_is_observed, [-1])
        prev_x_token_is_observed = tf.reshape(prev_x_token_is_observed, [-1])
        x_is_valid = tf.cast(1 - x_paddings, tf.bool)
        x_is_valid = tf.reshape(x_is_valid, [-1])

        # Remap all the observed to <eos>, note some of these need a zero weight
        # (or else there would be <eos> and valid token in the same slot).
        target_indices = tf.cast(tf.reshape(x, [-1, 1]), tf.int32)
        target_indices = tf.where(
            x_token_is_observed,
            tf.fill(py_utils.GetShape(target_indices), eos_id), target_indices)

        # TODO(williamchan): We give uniform 1.0 weight, however, math suggests
        # we may want to weigh this term by the original sequence length.
        target_weights = tf.ones_like(target_indices, tf.float32)

        # We need to set all the weights for <eos> which actually have valid tokens
        # in the slot to zero.
        target_weights = tf.where(
            x_token_is_observed & ~prev_x_token_is_observed,
            tf.zeros_like(target_weights), target_weights)

        # TODO(williamchan): Consider dropping the entries w/ weight zero.

        # Add the batch and slot indices.
        target_indices = tf.concat([
            tf.reshape(
                tf.tile(tf.expand_dims(tf.range(batch_size), 1),
                        [1, time_dim]), [batch_size * time_dim, 1]),
            tf.reshape(x_segments, [-1, 1]), target_indices
        ], 1)

        # Select only the valid indices. The selected valid ones include slots w/
        # <eos>.
        target_indices = target_indices[x_is_valid]
        target_weights = target_weights[x_is_valid]

        return py_utils.NestedMap(canvas=c,
                                  canvas_indices=c_indices,
                                  canvas_paddings=c_paddings,
                                  target_indices=target_indices,
                                  target_weights=target_weights)
Beispiel #19
0
    def _CreateCanvasAndTargets(self, batch):
        # pyformat: disable
        """Create the canvas and targets.

        Args:
          batch: A `.NestedMap`.

            - src: A `.NestedMap`.
              - ids: The source ids, ends in <eos>.
              - paddings: The source paddings.

            - tgt: A `.NestedMap`.
              - ids: The target ids, ends in <eos>.
              - paddings: The target paddings.

        Returns:
          When not in eval mode, a `NestedMap`.
            - canvas: The canvas (based off of the `rollin_policy`) of shape
              [batch_size, c_dim].
            - canvas_paddings: The paddings of `canvas_indices`.
            - target_indices: The target indices (i.e., use these indices to
              tf.gather_nd the log-probs).
            - target_weights: The target weights.
          In eval mode nothing is built and None is returned.
        """
        # pyformat: enable
        p = self.params

        if self.do_eval:
            return None

        # Sample our src and tgt canvas.
        src = self._SampleCanvasAndTargets(batch.src.ids, batch.src.paddings)
        tgt = self._SampleCanvasAndTargets(batch.tgt.ids, batch.tgt.paddings)

        # Offset the src ids (to unshare embeddings between src/tgt). Note, we
        # only offset the canvas ids, but we do not offset the vocab ids. This
        # will result in unshared embeddings, but shared softmax. This is due
        # to GPU/TPU memory limitations, empirically it is known that
        # unsharing everything results in better performance.
        vocab_size = p.decoder.softmax.num_classes
        src_is_valid = tf.equal(src.canvas_paddings, 0)
        src.canvas = tf.where(src_is_valid, src.canvas + vocab_size,
                              src.canvas)

        # Offset the tgt indices (need shift according to src length).
        batch_size = py_utils.GetShape(batch.src.ids)[0]
        # `target_batch` is a [num_targets, batch_size] one-hot tensor where
        # each row identifies which batch the target belongs to, i.e.
        # tf.reduce_sum(target_batch, 1) == 1 for all rows.
        all_batch_ids = tf.expand_dims(tf.range(batch_size), 0)
        target_batch_ids = tf.expand_dims(tgt.target_indices[:, 0], 1)
        target_batch = tf.cast(
            tf.equal(all_batch_ids, target_batch_ids), tf.int32)
        src_lens = tf.cast(
            tf.reduce_sum(1 - src.canvas_paddings, 1), tf.int32)
        # `tgt_offset` is [num_targets, 1]: for each target, the source canvas
        # length of the batch element it belongs to.
        tgt_offset = tf.matmul(target_batch, tf.expand_dims(src_lens, 1))
        # We shift the tgt slot without touching the batch or vocab.
        slot_shift = tf.concat([
            tf.zeros_like(tgt_offset), tgt_offset,
            tf.zeros_like(tgt_offset)
        ], 1)
        tgt.target_indices += slot_shift

        # The canvas is simply the sequence-level concat of the src and tgt.
        canvas, canvas_paddings = insertion.SequenceConcat(
            src.canvas, src.canvas_paddings, tgt.canvas, tgt.canvas_paddings)
        target_indices = tf.concat(
            [src.target_indices, tgt.target_indices], 0)
        target_weights = tf.concat(
            [src.target_weights, tgt.target_weights], 0)

        return py_utils.NestedMap(canvas=canvas,
                                  canvas_paddings=canvas_paddings,
                                  target_indices=target_indices,
                                  target_weights=target_weights)
def beam_search_step(in_scores,
                     in_atten_probs,
                     in_best_scores,
                     in_cumulative_scores,
                     in_histories,
                     cur_step,
                     eos_id,
                     num_beams,
                     beam_size,
                     num_hyps_per_beam,
                     valid_eos_max_logit_delta=5.0,
                     local_eos_threshold=-100.0,
                     merge_paths=False,
                     is_last_chunk=None,
                     eoc_id=0):
    """A single step of beam search.

  Let "b" be the number of beams, "k" be the number hyps in each beam. This
  function supports values with dtypes tf.float32 or tf.bfloat16.

  The state tensors in_best_scores, in_cumulative_scores and in_histories are
  presumably created by the caller before the first decoding step; their
  updated counterparts (out_best_scores, out_cumulative_scores, out_histories)
  are returned so they can be passed along from the current step to the next.

  Inputs are reordered with reorder_tensor("mod_to_div", ...) on entry and the
  per-hyp outputs are reordered back with "div_to_mod" before being returned.

  Args:
    in_scores: A tensor of shape [b * k, vocab_size], where [i, ...] is the
      token score of the j-th hyps of the n-th beam. j = (i / k), and n = i % k
      NOTE(review): the previous-hyp index built near the end of this function
      is `hyp * num_beams + beam`, which suggests j = i // b and n = i % b for
      these inputs -- confirm against callers.
    in_atten_probs: A tensor of shape [b*k, s_len], where in_atten_probs[i, ...]
      is the attention probabilities over the source words of the j-th hyps of
      n-th beam (where j, and n are derived as above).
    in_best_scores: A vector of size [b], best scores of terminated hyps so far
      in each of the beams.
    in_cumulative_scores: A vector of size [b * k]. The cumulative score of each
      active hyp before the current step.
    in_histories: An int32 vector of size [b * k] containing hashes of the
      histories of each active hyp. If 'merge_paths' is enabled, the histories
      are used to identify hypotheses that are identical modulo epsilons (e.g.
      "a <eps> b" and "a b <eps>") and merge them. See 'update_histories'
      docstring for details.
    cur_step: Current step id. Step 0 is treated specially: only the first hyp
      of each beam is kept and all others are pushed to BEST_SCORES_INIT.
    eos_id: Token id of the special end of sequence token.
    num_beams: Number of beams.
    beam_size: Score margin used for termination. out_all_done is true once no
      entry of out_cumulative_scores is greater than the corresponding
      out_best_scores entry minus beam_size.
    num_hyps_per_beam: Number of hyps in a beam. Converted with int() and must
      be > 0.
    valid_eos_max_logit_delta: We allow </s> to terminate a hyp only if its
      logit is no more than 'valid_eos_max_logit_delta' away from the logit of
      the best candidate.
    local_eos_threshold: We allow </s> to terminate a hyp if the local score for
      </s> is greater than local_eos_threshold.
    merge_paths: If true, hyps which are identical when epsilons are removed
      will be combined into a single hyp.  The probability for that combined hyp
      will be the sum of the probabilities of the component hyps.  This can only
      be applied for epsilon-emitting models (RNN-T and NT).
    is_last_chunk: A tensor of shape [b * k, 1]. Used by neural transducer,
      determines whether the current hypothesis reaches the last chunk and
      should treat the next end-of-chunk symbol as end-of-sentence. When None,
      an all-False tensor is used.
    eoc_id: int, the id of the end of chunk (a.k.a epsilon) token used by neural
      transducer models. Only relevant if 'merge_paths' is True or
      'is_last_chunk' is provided.

  Returns:
    An 11-tuple, in this order:

    out_best_scores: A tensor of shape [b] of updated best scores for each of
      the beams.
    out_cumulative_scores: A tensor of shape [b * k]. The cumulative score of
      the new hyps after the current decoding step.
    out_scores: A tensor of shape [b * k] with scores of the token selected.
    out_eos_scores: A tensor of shape [b * k] with token scores for the EOS, in
      case the hyp was terminated, otherwise BEST_SCORES_INIT.
    out_hyps: A tensor of shape [b * k] with ids of the token selected.
    out_prev_hyps: A tensor of shape [b * k] with index to the previous hyps
      which was selected.
    out_done_hyps: A boolean tensor of shape [b * k] where value indicates
      if hyps was terminated.
    out_atten_probs: A tensor of shape [b * k, seq_len] which contain the
      attention probabilities over the source words against word in the previous
      hyps.
    out_eos_atten_probs: A tensor of shape [b * k, seq_len] which contains the
      attention probabilities over the source against word in the current hyp
      which was terminated (rows of hyps that did not terminate are zeroed).
    out_all_done: A scalar, whether decoding should terminate for all beams.
    out_histories: A tensor of shape [b * k] containing new history hashes for
      the active hypotheses. See 'update_histories' docstring for details. When
      'merge_paths' is False this is in_histories, returned unchanged.
  Raises:
    ValueError: if num_hyps_per_beam is not > 0.
  """
    num_hyps_per_beam = int(num_hyps_per_beam)

    if num_hyps_per_beam <= 0:
        raise ValueError("num_hyps_per_beam = {} and must be > 0.".format(
            num_hyps_per_beam))

    # Normalize every input to a tensor and check its static rank up front.
    in_scores = tf.convert_to_tensor(in_scores)
    in_scores.shape.assert_has_rank(2)
    # Static size of the vocabulary dimension.
    num_classes = in_scores.get_shape()[1]

    in_atten_probs = tf.convert_to_tensor(in_atten_probs)
    in_atten_probs.shape.assert_has_rank(2)

    in_best_scores = tf.convert_to_tensor(in_best_scores)
    in_best_scores.shape.assert_has_rank(1)

    in_cumulative_scores = tf.convert_to_tensor(in_cumulative_scores)
    in_cumulative_scores.shape.assert_has_rank(1)

    in_histories = tf.convert_to_tensor(in_histories)
    in_histories.shape.assert_has_rank(1)

    with tf.name_scope("beam_search_step"):
        # For k = num_hyps_per_beam
        # First step of beam search is to find the top tokens based on its score.
        # Normally we select k+1, where the extra +1 is to make sure we have k
        # non-eos tokens to select if EOS token is in the top-k. If path merging is
        # on, we actually need to select k+2; this ensures there are k+1 tokens left
        # after the merge, at least k of which are not EOS.
        # TODO(b/118644069): Avoid casts when there is a XLA op available that takes
        # in bfloat16.
        num_candidates_per_input_hyp = (num_hyps_per_beam + 2 if merge_paths
                                        else num_hyps_per_beam + 1)
        # [b * k, num_candidates_per_input_hyp]
        local_score_values, local_indices = xla_ops.top_k_with_unique(
            tf.cast(in_scores, tf.float32), k=num_candidates_per_input_hyp)
        # Cast the scores back to the caller's dtype after the float32 top-k.
        local_score_values = tf.cast(local_score_values, in_scores.dtype)

        # Compute the global score which is sum of the local score, and the
        # cumulative scores for each of the hyps.
        # [b * k, num_candidates_per_input_hyp]
        global_score_values = local_score_values + tf.expand_dims(
            in_cumulative_scores, 1)

        values_dtype = local_score_values.dtype
        # 1.0 on step 0 and 0.0 afterwards; used to switch the first-step masks
        # on and off arithmetically.
        is_first_step = tf.cast(tf.equal(cur_step, 0), values_dtype)

        # Preprocessing to reorder the tensor from `mod` sharding to `div` so that
        # we can use matrix/vector operations to complete the beam search.
        # [b * k, num_candidates_per_input_hyp]
        global_score_values = reorder_tensor("mod_to_div", global_score_values,
                                             num_beams, num_hyps_per_beam)
        local_score_values = reorder_tensor("mod_to_div", local_score_values,
                                            num_beams, num_hyps_per_beam)
        local_indices = reorder_tensor("mod_to_div",
                                       local_indices,
                                       num_beams,
                                       num_hyps_per_beam,
                                       max_value=num_classes - 1)
        # [b * k, 1]
        histories = reorder_tensor("mod_to_div",
                                   tf.expand_dims(in_histories, 1), num_beams,
                                   num_hyps_per_beam)
        # Default to "never the last chunk"; otherwise bring the caller's flags
        # into the same `div` ordering and make them boolean.
        if is_last_chunk is None:
            is_last_chunk = tf.zeros([num_beams * num_hyps_per_beam, 1],
                                     tf.bool)
        else:
            is_last_chunk = tf.cast(
                reorder_tensor(
                    "mod_to_div",
                    tf.reshape(is_last_chunk,
                               [num_beams * num_hyps_per_beam, 1]), num_beams,
                    num_hyps_per_beam), tf.bool)

        # For the first step mask everything but the first row.
        # [num_hyps_per_beam]
        per_example_mask = tf.concat([
            tf.constant([1.0], dtype=values_dtype),
            tf.zeros([num_hyps_per_beam - 1], dtype=values_dtype)
        ], 0)
        # Tile the per-beam pattern num_beams times => [b * k, 1]. On any step
        # other than the first the mask is all ones.
        mask = tf.reshape(
            tf.tile(per_example_mask, tf.expand_dims(num_beams, 0)),
            [-1, 1]) * is_first_step + (1.0 - is_first_step)
        local_score_values *= mask
        global_score_values *= mask

        # On the first step, add BEST_SCORES_INIT to the rows that `mask` zeroed
        # out above (every hyp of a beam except the first); on later steps the
        # additive mask is all zeros.
        per_example_additive_mask = tf.concat([
            tf.constant([0.0], dtype=values_dtype),
            tf.constant(BEST_SCORES_INIT,
                        shape=[num_hyps_per_beam - 1],
                        dtype=values_dtype)
        ], 0)
        additive_mask = tf.reshape(
            tf.tile(per_example_additive_mask, tf.expand_dims(num_beams, 0)),
            [-1, 1]) * is_first_step
        local_score_values += additive_mask
        global_score_values += additive_mask

        if merge_paths:
            with tf.name_scope("merge_paths"):
                # Compute new history hashes for each hypothesis + new token.
                # [b * k, num_candidates_per_input_hyp]
                histories = update_histories(histories,
                                             local_indices,
                                             mask,
                                             epsilon_id=eoc_id)
                # Merge candidates that share a history hash within a beam.
                global_score_values, histories = merge_hyps(
                    global_score_values, histories, mask, num_beams,
                    num_hyps_per_beam)

        # As we keep num_candidates_per_input_hyp, we have a total of
        # num_candidates_per_input_hyp * k hyps active per example.
        num_candidate_hyps = num_candidates_per_input_hyp * num_hyps_per_beam
        batch_shape = [-1, num_candidate_hyps]

        # Reshape score values so that each row corresponds to a particular example.
        # [num_beams, num_candidate_hyps]
        global_score_values_batch = tf.reshape(global_score_values,
                                               batch_shape)

        # First for each beam: Find the top 2 * num_hyps_per_beam candidates.
        # The factor of 2 is to be able to process non EOS token ids in the case
        # where top scoring token for each hyps is EOS token.
        # [num_beams, 2 * num_hyps_per_beam]
        _, candidates_indices_in_top_k = xla_ops.top_k_with_unique(
            tf.cast(global_score_values_batch, tf.float32),
            k=2 * num_hyps_per_beam)
        # Find the previous hyps of the candidate. We divide here by
        # num_candidates_per_input_hyp (k+1, or k+2 with path merging) to
        # identify which hyps this token came from.
        hyps_id = candidates_indices_in_top_k // num_candidates_per_input_hyp

        # Add a per-beam offset so that the candidate indices address the
        # flattened [num_beams * num_candidate_hyps] candidate space.
        offset = tf.expand_dims(tf.range(num_beams) * num_candidate_hyps, 1)
        flat_candidates_indices_in_top_k = tf.reshape(
            candidates_indices_in_top_k + offset, [-1])

        # Flattened views of the candidate tables used as gather sources below.
        # Ids are laid out as a single row (gathered along axis=1); scores as a
        # single column.
        flat_local_indices = tf.reshape(local_indices, [1, -1])
        flat_token_scores = tf.reshape(local_score_values, [-1, 1])
        flat_global_scores = tf.reshape(global_score_values, [-1, 1])

        # Gather the token scores for each of 2*k candidates. We use tf.one_hot()
        # followed by a tf.matmul() to speedup gather on TPUs.
        total_num_candidates = num_beams * num_candidate_hyps
        token_scores_for_beam = tf.reshape(
            fast_gather(flat_token_scores, flat_candidates_indices_in_top_k,
                        total_num_candidates),
            [num_beams, 2 * num_hyps_per_beam])
        token_scores_for_beam_shape = tf.shape(token_scores_for_beam)

        global_scores_for_beam = tf.reshape(
            fast_gather(flat_global_scores, flat_candidates_indices_in_top_k,
                        total_num_candidates), token_scores_for_beam_shape)

        # Local index values are between [0, vocab_size-1], hence we use the
        # slower version of gather.
        token_ids_for_beam = tf.reshape(
            fast_gather(flat_local_indices,
                        flat_candidates_indices_in_top_k,
                        total_num_candidates,
                        max_value=num_classes - 1,
                        axis=1), token_scores_for_beam_shape)

        # We have access to 2*num_hyps_per_beam hyps per beam.
        # We shrink back to num_hyps_per_beam that does not include EOS, and move
        # EOS that occurs in top-num_hyps_per_beam to the EOS done matrix.

        # To determine the threshold at which eos is allowed to terminate a hyp,
        # we need to know the maximum global score for that hyp with any additional
        # token. If path merging is *not* enabled, the global_score_values are
        # by construction in sorted order, so we can just look at its 0th column. If
        # path merging is enabled, the global scores of deleted (merged) hyps break
        # the sorted order, which means we have to do a full reduce_max.
        if merge_paths:
            max_global_score_per_input_hyp = tf.reduce_max(global_score_values,
                                                           axis=1,
                                                           keepdims=True)
        else:
            max_global_score_per_input_hyp = global_score_values[:, 0:1]
        # [num_beams * num_hyps_per_beam, 1]
        global_eos_threshold = (max_global_score_per_input_hyp -
                                valid_eos_max_logit_delta)
        local_eos_threshold_tensor = local_eos_threshold * tf.ones_like(
            global_eos_threshold)

        # Find EOS in top num_hyps_per_beam token ids. We also treat EOC as EOS if
        # the model has indicated this is the last chunk.
        local_index_is_eos = tf.equal(local_indices, eos_id)
        local_index_is_last_chunk_eoc = tf.math.logical_and(
            tf.equal(local_indices, eoc_id), is_last_chunk)
        # A candidate terminates its hyp only if all four conditions hold: its
        # local score beats local_eos_threshold, its global score beats the
        # per-hyp global threshold, it is EOS (or last-chunk EOC), and the row
        # is not masked out.
        eos_mask = tf.math.logical_and(
            tf.math.logical_and(
                tf.math.logical_and(
                    tf.greater(
                        local_score_values,
                        tf.tile(local_eos_threshold_tensor,
                                [1, num_candidates_per_input_hyp])),
                    tf.greater(
                        global_score_values,
                        tf.tile(global_eos_threshold,
                                [1, num_candidates_per_input_hyp]))),
                tf.math.logical_or(local_index_is_eos,
                                   local_index_is_last_chunk_eoc)),
            tf.cast(mask, tf.bool))
        # A hyp is terminated if any of its candidates qualifies. [b * k, 1]
        end_hyps_bool_mask = tf.reshape(tf.reduce_any(eos_mask, 1), [-1, 1])

        end_hyps_bool_mask = reorder_tensor("div_to_mod", end_hyps_bool_mask,
                                            num_beams, num_hyps_per_beam)

        # Keep the attention probabilities of terminated hyps only; rows of
        # hyps that did not terminate become zero.
        eos_atten_probs = in_atten_probs * tf.cast(end_hyps_bool_mask,
                                                   in_atten_probs.dtype)
        eos_atten_probs = tf.reshape(eos_atten_probs,
                                     [num_beams * num_hyps_per_beam, -1])
        # A boolean tensor of shape [b * k] where value indicates if hyps was
        # terminated.
        out_done_hyps = tf.reshape(end_hyps_bool_mask, [-1])

        # Scores for EOS token. Non-terminating candidates are replaced by
        # BEST_SCORES_INIT so that the row-wise max picks the terminating
        # candidate when there is one, and BEST_SCORES_INIT otherwise.
        eos_float_mask = tf.cast(eos_mask, values_dtype)
        eos_local_scores = eos_float_mask * local_score_values
        eos_additive_float_mask = (1.0 - eos_float_mask) * BEST_SCORES_INIT
        eos_local_scores += eos_additive_float_mask
        out_eos_scores = tf.reshape(tf.reduce_max(eos_local_scores, 1),
                                    [-1, 1])
        out_eos_scores = tf.reshape(
            reorder_tensor("div_to_mod", out_eos_scores, num_beams,
                           num_hyps_per_beam), [-1])
        # A tensor of shape [b] with the best global score among the candidates
        # that terminate in this step, per beam.
        eos_global_scores = eos_float_mask * global_score_values
        eos_global_scores += eos_additive_float_mask
        best_scores = tf.reduce_max(
            tf.reshape(eos_global_scores, [num_beams, -1]), 1)

        # The following operations find the top num_hyps_per_beam candidates
        # that are still active.

        # Active ones are the ones that do not correspond to EOS termination.
        # We keep num_hyps_per_beam * 2 in case every hyps is terminated by EOS id.
        # Top K with eos removed.
        # NOTE(review): only eos_id is filtered out here; last-chunk EOC
        # candidates that counted as terminations in eos_mask above stay
        # eligible as active hyps -- confirm this is intended.
        non_eos_mask = tf.not_equal(token_ids_for_beam, eos_id)
        # NOTE: num_candidate_hyps is rebound here with a different meaning: it
        # is now the total number of top-k candidates across all beams
        # (2 * k * b) and doubles as the "masked" sentinel index below.
        num_candidate_hyps = num_hyps_per_beam * 2 * num_beams
        # Flat candidate index for non-EOS entries, sentinel for EOS entries.
        index = tf.where(
            non_eos_mask,
            tf.reshape(tf.range(num_candidate_hyps, dtype=tf.int32),
                       token_scores_for_beam_shape),
            num_candidate_hyps *
            tf.ones(dtype=tf.int32, shape=token_scores_for_beam_shape))

        # Unrolled TopK.
        sorted_indices = []
        # Finds the first num_hyps_per_beam unmasked indexes and stores them in
        # concated_sorted_indices (shape: [num_beams, num_hyps_per_beam]).
        # This is done by iteratively recording the min index in each row, and
        # resetting it to the max, so that the next iteration's reduce_min
        # returns the 2nd minimum index.
        for _ in range(num_hyps_per_beam):
            min_index = tf.reshape(tf.reduce_min(index, [1]), [num_beams, 1])
            sorted_indices.append(min_index)
            # Replace position with num_candidate_hyps value.
            index = tf.where(
                tf.equal(index, min_index),
                num_candidate_hyps *
                tf.ones(dtype=tf.int32, shape=token_scores_for_beam_shape),
                index)

        # Post processing ops to output expected tensors.
        concated_sorted_indices = tf.concat(sorted_indices, 1)
        flat_sorted_indices = tf.reshape(concated_sorted_indices, [-1])

        # A tensor of shape [b * k] with scores of the token selected.
        out_scores = tf.reshape(
            fast_gather(tf.reshape(token_scores_for_beam, [-1, 1]),
                        flat_sorted_indices, num_candidate_hyps), [-1, 1])
        out_scores = tf.reshape(
            reorder_tensor("div_to_mod", out_scores, num_beams,
                           num_hyps_per_beam), [-1])

        # Gather the updated histories of selected hypotheses if path merging is
        # enabled. Otherwise, the histories are unused, so just output in_histories.
        if merge_paths:
            flat_histories = tf.reshape(histories, [-1, 1])
            # [num_beams, 2 * num_hyps_per_beam]
            histories_for_beam = tf.reshape(
                fast_gather(flat_histories, flat_candidates_indices_in_top_k,
                            total_num_candidates), token_scores_for_beam_shape)
            out_histories = tf.reshape(
                fast_gather(tf.reshape(histories_for_beam, [-1, 1]),
                            flat_sorted_indices, num_candidate_hyps), [-1, 1])
            out_histories = tf.reshape(
                reorder_tensor("div_to_mod", out_histories, num_beams,
                               num_hyps_per_beam), [-1])
        else:
            out_histories = in_histories

        # Turn the within-beam source-hyp id of every selected candidate into an
        # index into the [b * k] inputs: hyp_id * num_beams + beam_id.
        prev_hyps_ids = tf.reshape(
            tf.reshape(
                fast_gather(tf.reshape(hyps_id, [1, -1]),
                            flat_sorted_indices,
                            num_candidate_hyps,
                            max_value=num_hyps_per_beam,
                            axis=1), [num_beams, -1]) * num_beams +
            tf.expand_dims(tf.range(num_beams), 1), [-1, 1])

        prev_hyps_ids = reorder_tensor("div_to_mod",
                                       prev_hyps_ids,
                                       num_beams,
                                       num_hyps_per_beam,
                                       max_value=num_hyps_per_beam)
        # A tensor of shape [b * k] with index to the previous hyps which was
        # selected.
        out_prev_hyps = tf.reshape(prev_hyps_ids, [-1])

        # A tensor of shape [b * k, seq_len] which contain the attention
        # probabilities over the source words against word in the previous hyps.
        out_atten_probs = tf.reshape(
            fast_gather(in_atten_probs, out_prev_hyps,
                        num_beams * num_hyps_per_beam),
            [num_beams * num_hyps_per_beam, -1])

        # Token ids of the selected candidates, gathered along axis=1 from the
        # single-row layout and then reordered back.
        sorted_top_k_ids = fast_gather(tf.reshape(token_ids_for_beam, [1, -1]),
                                       flat_sorted_indices,
                                       num_candidate_hyps,
                                       max_value=num_classes - 1,
                                       axis=1)
        sorted_top_k_ids = reorder_tensor("div_to_mod",
                                          sorted_top_k_ids,
                                          num_beams,
                                          num_hyps_per_beam,
                                          max_value=num_classes - 1,
                                          axis=1)

        # A tensor of shape [b * k] with ids of the token selected.
        out_hyps = tf.reshape(sorted_top_k_ids, [-1])

        # A tensor of shape [b * k]. The cumulative score of the selected hyps after
        # the current decoding step.
        out_cumulative_scores = tf.reshape(
            fast_gather(tf.reshape(global_scores_for_beam, [-1, 1]),
                        flat_sorted_indices, num_candidate_hyps), [-1, 1])

        out_cumulative_scores = tf.reshape(
            reorder_tensor("div_to_mod", out_cumulative_scores, num_beams,
                           num_hyps_per_beam), [-1])
        # Best terminated score per beam: the better of this step's and the
        # running best passed in by the caller.
        out_best_scores = tf.maximum(best_scores, in_best_scores)

        # A scalar, whether decoding should terminate for all beams: true when
        # no active hyp scores above (best terminated score - beam_size).
        # NOTE(review): the tiled thresholds are flattened beam-major
        # ([b, k] -> [b * k]) while out_cumulative_scores has already been
        # passed through "div_to_mod"; confirm the two orderings line up.
        out_all_done = tf.reshape(
            tf.math.logical_not(
                tf.reduce_any(
                    tf.greater(
                        out_cumulative_scores,
                        tf.reshape(
                            tf.tile(
                                tf.reshape(out_best_scores - beam_size,
                                           [-1, 1]), [1, num_hyps_per_beam]),
                            [-1])))), [])

        return (out_best_scores, out_cumulative_scores, out_scores,
                out_eos_scores, out_hyps, out_prev_hyps, out_done_hyps,
                out_atten_probs, eos_atten_probs, out_all_done, out_histories)
def merge_hyps(global_score_values, histories_in, mask, num_beams,
               num_hyps_per_beam):
    """Collapses candidate hypotheses that share a history hash.

  Within each beam, every pair of valid candidates whose history hashes are
  equal is merged: the lower-scoring candidate (the merge "source") is deleted
  by giving it a global score of BEST_SCORES_INIT and a history of 0, while the
  higher-scoring one (the merge "destination") gets the _log_sum_exp of both
  scores. On a tie the candidate that sorts first by history keeps the mass.

  All input Tensors are assumed to be in "div" hypothesis ordering. That is,
  element [i, ...] corresponds to the j-th hyp of the n-th beam, where j = i % k
  and n = i / k.

  Example:
    With num_beams = 1, num_hyps_per_beam = 2, candidates_per_hyp = 5,
    global_score_values
      [[11 12 13 14 15],
       [17 16 10 19 20]]
    and histories_in
      [[1 2 3 4 5],
       [5 6 3 7 8]],
    the candidates hashed 3 (scores 13 and 10) and the candidates hashed 5
    (scores 15 and 17) are merged, giving global scores
      [[ 11     12 13.04 14 -1e34 ],
       [ 17.13  16 -1e34 19 20    ]]
    and histories
      [[1 2 3 4 0],
       [5 6 0 7 8]].
    (Note _log_sum_exp(13, 10) ~= 13.04 and _log_sum_exp(15, 17) ~= 17.13.)

  Only one round of pair-wise merging is performed. Per the design notes of
  this routine that is sufficient because (ignoring hash collisions) at most
  two candidates can share a history: input hyps are unique, so a clash must
  have the form h_i <eps> == h_j s for distinct inputs h_i, h_j and some
  non-epsilon symbol s.

  Args:
    global_score_values: Tensor of shape [b * k, candidates_per_hyp], the global
      scores of each candidate hypothesis.
    histories_in: int32 Tensor of shape [b * k, candidates_per_hyp], the
      histories of each candidate hypothesis.
    mask: Tensor of shape [b * k, 1] indicating which entries in
      global_score_values and histories_in are valid.
    num_beams: int, the number of beams (b above).
    num_hyps_per_beam: int, the number of hypotheses per beam (k above).

  Returns:
    A tuple (global_score_values, histories) with the merges applied, both
    gathered back into the original candidate order and reshaped to the shape
    of the input global_score_values.
  """
    score_dtype = global_score_values.dtype
    candidates_per_hyp = histories_in.get_shape()[1]
    # Number of candidates in one beam, and across all beams.
    per_beam = num_hyps_per_beam * candidates_per_hyp
    total = num_beams * num_hyps_per_beam * candidates_per_hyp
    result_shape = tf.shape(global_score_values)

    # Group everything by beam: [b, k * candidates_per_hyp].
    beam_histories = tf.reshape(histories_in, [num_beams, -1])
    tiled_mask = tf.tile(mask, [1, candidates_per_hyp])
    beam_valid = tf.cast(tf.reshape(tiled_mask, [num_beams, -1]), score_dtype)

    # Permutation that sorts each beam by history, and its inverse.
    sort_perm = tf.argsort(beam_histories, axis=1)
    unsort_perm = tf.argsort(sort_perm, axis=1)

    def _flatten_perm(perm):
        """Turns per-row column indices into indices of the flattened tensor."""
        perm.shape.assert_has_rank(2)
        beam_offsets = tf.reshape(tf.range(num_beams), [num_beams, 1])
        return tf.reshape(perm + per_beam * beam_offsets, [-1])

    # Linear indices, so that fast_gather can be used.
    sort_flat = _flatten_perm(sort_perm)
    unsort_flat = _flatten_perm(unsort_perm)

    def _gather_flat(values, flat_indices, out_shape):
        """Gathers `values` as a column vector and reshapes the result."""
        column = tf.reshape(values, [-1, 1])
        return tf.reshape(fast_gather(column, flat_indices, total), out_shape)

    def _to_sorted(values):
        return _gather_flat(values, sort_flat, [num_beams, per_beam])

    def _to_original(values):
        return _gather_flat(values, unsort_flat, result_shape)

    hist_sorted = _to_sorted(beam_histories)
    valid_sorted = _to_sorted(beam_valid)

    # 1.0 where two adjacent (sorted) candidates are both valid and share a
    # history. [b, k * candidates_per_hyp - 1]
    same_history = tf.cast(
        tf.equal(hist_sorted[:, 1:], hist_sorted[:, :-1]), score_dtype)
    both_valid = valid_sorted[:, 1:] * valid_sorted[:, :-1]
    dup_mask = same_history * both_valid
    zero_col = tf.zeros([num_beams, 1], dtype=score_dtype)
    dup_of_left = tf.concat([zero_col, dup_mask], axis=1)
    dup_of_right = tf.concat([dup_mask, zero_col], axis=1)

    # Each candidate's score next to the scores of its two sorted neighbors.
    scores_sorted = _to_sorted(global_score_values)
    right_scores = tf.concat([scores_sorted[:, 1:], zero_col], axis=1)
    left_scores = tf.concat([zero_col, scores_sorted[:, :-1]], axis=1)

    # Direction of each potential merge. The comparison is non-strict against
    # the right neighbor and strict against the left one, so exactly one member
    # of a tied pair wins.
    beats_right = tf.cast(
        tf.greater_equal(scores_sorted, right_scores), score_dtype)
    loses_to_right = 1.0 - beats_right
    beats_left = tf.cast(tf.greater(scores_sorted, left_scores), score_dtype)
    loses_to_left = 1.0 - beats_left

    # Classify every candidate as source (deleted), destination, or untouched.
    merge_source = tf.minimum(
        dup_of_left * loses_to_left + dup_of_right * loses_to_right, 1.0)
    dest_from_left = dup_of_left * beats_left
    dest_from_right = dup_of_right * beats_right
    merge_dest = tf.minimum(dest_from_left + dest_from_right, 1.0)
    untouched = tf.maximum(1.0 - merge_source - merge_dest, 0.0)

    # Blend the four mutually exclusive outcomes into the new scores.
    kept_term = untouched * scores_sorted
    deleted_term = merge_source * BEST_SCORES_INIT
    left_term = dest_from_left * _log_sum_exp(left_scores, scores_sorted)
    right_term = dest_from_right * _log_sum_exp(right_scores, scores_sorted)
    merged_scores = kept_term + deleted_term + left_term + right_term

    # Deleted candidates get an empty (zero) history.
    keep_history = tf.cast(1.0 - merge_source, hist_sorted.dtype)
    merged_histories = hist_sorted * keep_history

    # Undo the sort and restore the original rank.
    return _to_original(merged_scores), _to_original(merged_histories)
 def _GetDefaultPaddings(self, inputs):
     """Builds all-zero paddings matching `inputs` with a last dim of size 1.

     The result has the same leading dimensions and dtype as `inputs`; only
     the final dimension is collapsed to 1.
     """
     leading_dims = tf.shape(inputs)[:-1]
     paddings_shape = tf.concat([leading_dims, [1]], 0)
     return tf.zeros(paddings_shape, dtype=inputs.dtype)
def MergeBeamSearchOutputs(max_hyps_per_beam, beam_search_outputs):
    """Combines the hyps of several decoders and keeps the best ones per beam.

    Every non-None field of each output is reshaped to
    [source_batch, hyps_per_beam, -1], the per-decoder tensors are concatenated
    along the hyps dimension, and the `max_hyps_per_beam` highest-scoring
    entries of every beam are gathered back into a single output.

    Args:
      max_hyps_per_beam: the number of top hyps in the merged results. Must be
        less than or equal to total number of input hyps.
      beam_search_outputs: a list of BeamSearchDecodeOutput objects. Must share
        the same source_batch and max sequence length.

    Returns:
      A BeamSearchDecodeOutput object containing max_hyps_per_beam hypotheses
      per beam.

    Raises:
      ValueError: if a field is set in some but not all of the outputs, or if
        a merged field is not one this function knows how to reshape.
    """
    first_output = beam_search_outputs[0]
    source_batch = tf.shape(first_output.topk_hyps)[0]

    # Group, by field name, one [source_batch, hyps_per_beam, -1] tensor for
    # every decoder output that has the field set.
    per_field = {}
    for decoder_out in beam_search_outputs:
        # Every decoder must agree with the first one on the batch dimension.
        batch_matches = py_utils.assert_equal(
            source_batch, tf.shape(decoder_out.topk_hyps)[0])
        hyps_per_beam = py_utils.with_dependencies(
            [batch_matches], tf.shape(decoder_out.topk_hyps)[1])
        for field, tensor in six.iteritems(decoder_out._asdict()):
            if tensor is None:
                continue
            if field == 'done_hyps':
                # NOTE(review): presumably done_hyps is laid out with the
                # flattened hyps on its second axis, hence the transpose
                # before splitting the leading dimension -- confirm.
                tensor = tf.transpose(tensor)
            per_field.setdefault(field, []).append(
                tf.reshape(tensor, [source_batch, hyps_per_beam, -1]))

    # Concatenate the tensors along the 'num_hyps_per_beam' dimension.
    num_outputs = len(beam_search_outputs)
    concatenated = {}
    for field, pieces in six.iteritems(per_field):
        if len(pieces) != num_outputs:
            raise ValueError('Incomplete values for %s: %s' %
                             (field, beam_search_outputs))
        concatenated[field] = tf.concat(pieces, axis=1)

    # Hyps with zero length get a very low score so that real hyps are
    # preferred by the top-k selection below.
    merged_scores = concatenated['topk_scores']
    is_empty_hyp = tf.equal(concatenated['topk_lens'], 0)
    floor_scores = tf.fill(tf.shape(merged_scores), -1e6)
    merged_scores = tf.squeeze(
        tf.where(is_empty_hyp, floor_scores, merged_scores), -1)

    # Select top max_hyps_per_beam indices per beam.
    top_indices = tf.nn.top_k(merged_scores, max_hyps_per_beam)[1]
    batch_column = tf.expand_dims(tf.range(source_batch), -1)
    batch_ids = tf.tile(batch_column, [1, max_hyps_per_beam])
    # [source_batch, max_hyps_per_beam, 2]
    gather_indices = tf.stack([batch_ids, top_indices], axis=-1)

    # Gather the merged top hyps according to 'gather_indices' and restore the
    # layout of each field. Fields left unset (None) in the inputs keep the
    # value they have in the first output.
    top = first_output._asdict()
    total_hyps = source_batch * max_hyps_per_beam
    flat_shape = [total_hyps]
    per_hyp_rows_shape = [total_hyps, -1]
    output_shapes = {
        'done_hyps': per_hyp_rows_shape,
        'topk_hyps': [source_batch, max_hyps_per_beam],
        'topk_ids': per_hyp_rows_shape,
        'topk_lens': flat_shape,
        'topk_scores': flat_shape,
        'topk_decoded': flat_shape,
    }
    for field, merged in six.iteritems(concatenated):
        selected = tf.gather_nd(merged, gather_indices)
        if field not in output_shapes:
            raise ValueError('Unexpected field: %s' % field)
        selected = tf.reshape(selected, output_shapes[field])
        if field == 'done_hyps':
            # Undo the transpose applied while collecting this field.
            selected = tf.transpose(selected)
        top[field] = selected
    return BeamSearchDecodeOutput(**top)