def _InputBatch(self):
        p = self.params

        @tf.function
        def ReadData():
            x, y = io_ops.restore_v2(p.ckpt, [p.data, p.label], [''] * 2,
                                     [p.data_dtype, p.label_dtype])
            # Always convert to float32.
            return tf.cast(x, tf.float32), tf.cast(y, tf.float32)

        # Load the data and labels into memory once and keep them around.
        data, label = ops.cached_call(f=ReadData.get_concrete_function(),
                                      T=[tf.float32, tf.float32])
        b, shape = self.InfeedBatchSize(), list(p.data_shape)
        data = tf.reshape(data, [-1] + shape)
        label = tf.reshape(label, [-1])
        label = py_utils.HasShape(label, [tf.shape(data)[0]])
        sample_ids = ops.random_permutation_sequence(
            num=p.num_samples,
            batch=b,
            repeat=p.repeat,
            seed=p.random_seed if p.random_seed else 0)
        n = tf.shape(sample_ids)[0]
        raw = py_utils.PadOrTrimTo(tf.gather(data, sample_ids), [b] + shape)
        ret = py_utils.NestedMap(
            raw=raw,
            data=self._Preprocess(raw),
            label=py_utils.PadOrTrimTo(tf.gather(label, sample_ids), [b]),
            weight=py_utils.PadOrTrimTo(tf.ones([n], dtype=tf.float32), [b]))
        if not py_utils.use_tpu():
            ret['sample_ids'] = sample_ids
        return ret
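
The custom ops.random_permutation_sequence op above drives the epoch-wise
shuffling. As a rough sketch of its contract only (not of the op's
implementation), written in plain TensorFlow:

import tensorflow as tf

def PermutationBatches(num_samples, batch, seed=1):
  """Yields id batches covering a fresh random permutation of the samples."""
  ids = tf.random.shuffle(tf.range(num_samples), seed=seed)
  for start in range(0, num_samples, batch):
    # The final batch may be short; the caller pads it, as PadOrTrimTo
    # does above.
    yield ids[start:start + batch]

for sample_ids in PermutationBatches(num_samples=10, batch=4):
  print(sample_ids.numpy())  # e.g. [7 2 9 1] [0 5 3 8] [6 4]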
Example #2
def ReOrderHyps(x_in):
    """Reorders x_in based on prev hyp ids."""
    if (isinstance(x_in, tf.Tensor) and x_in.shape.ndims
            and x_in.shape.ndims > 0):
        if x_in.shape.ndims > 2 and not p.batch_major_state:
            # Use corrected indices only here for batch major compute as
            # key/value caches are the states being affected.
            correct_old_hyp_ids = (old_hyp_ids_in_cache_order
                                   if p.batch_major_compute else old_hyp_ids)
            x_out = tf.gather(x_in, correct_old_hyp_ids, axis=1)
        else:
            x_out = tf.gather(x_in, old_hyp_ids)
        x_out.set_shape(x_in.get_shape())
        return x_out
    else:
        return x_in
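
Both branches implement the same idea: after beam pruning, surviving
hypothesis b continues old hypothesis old_hyp_ids[b], so every piece of
recurrent state must be permuted along its hypothesis axis. A minimal eager
sketch with hypothetical shapes:

import tensorflow as tf

old_hyp_ids = tf.constant([1, 1, 0, 2])    # hyp b continues old hyp ids[b]
batch_major = tf.random.normal([4, 8])     # [B, state_dim]: gather rows
time_major = tf.random.normal([5, 4, 8])   # [T, B, state_dim]: gather axis=1

print(tf.gather(batch_major, old_hyp_ids).shape)         # (4, 8)
print(tf.gather(time_major, old_hyp_ids, axis=1).shape)  # (5, 4, 8)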
Example #4
      def ApplyBias():
        """Bias and update log_probs and consistent."""

        def TileForBeamAndFlatten(tensor):
          tensor = tf.reshape(tensor, [1, -1])  # [1, src_batch]
          tensor = tf.tile(
              tensor, [num_hyps_per_beam, 1])  # [num_hyps_per_beam, src_batch]
          tgt_batch = tf.shape(step_ids)[0]  # num_hyps_per_beam*src_batch
          return tf.reshape(tensor, [tgt_batch])

        # Consistent if step_ids == labels from the previous step.
        # TODO(navari): Consider updating consistent only if weights > 0. Then
        # re-evaluate the need for bias_only_if_consistent=True.
        # Note that prev_label is incorrect for step 0 but is overridden later.
        prev_label = TileForBeamAndFlatten(
            tf.gather(labels, tf.maximum(time_step - 1, 0), axis=1))
        is_step0 = tf.equal(time_step, 0)
        local_consistence = tf.math.logical_or(
            is_step0, tf.equal(prev_label, tf.squeeze(step_ids, 1)))
        consistent = tf.math.logical_and(states.consistent, local_consistence)

        # Get the label and weight slices corresponding to the current
        # time_step.
        label = TileForBeamAndFlatten(tf.gather(labels, time_step, axis=1))
        weight = TileForBeamAndFlatten(tf.gather(weights, time_step, axis=1))
        if p.bias_only_if_consistent:
          weight = weight * tf.cast(consistent, p.dtype)

        # Convert dense label ids into near-one-hot label probs.
        vocab_size = tf.shape(bs_results.log_probs)[1]
        uncertainty = tf.constant(
            1e-10, p.dtype)  # avoid 0 probs which may cause issues with log
        label_probs = tf.one_hot(
            label,
            vocab_size,
            on_value=1 - uncertainty,
            off_value=uncertainty / tf.cast(vocab_size - 1, p.dtype),
            dtype=p.dtype)  # [tgt_batch, vocab_size]
        pred_probs = tf.exp(bs_results.log_probs)

        # interpolate predicted probs and label probs
        weight = tf.expand_dims(weight, 1)
        probs = py_utils.with_dependencies([
            py_utils.assert_less_equal(weight, 1.),
            py_utils.assert_greater_equal(weight, 0.)
        ], (1.0 - weight) * pred_probs + weight * label_probs)
        return tf.math.log(probs), consistent
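
The bias above is label smoothing plus a per-example interpolation between
the model's distribution and a near-one-hot reference distribution. A toy
eager sketch of just that arithmetic, with hypothetical sizes:

import tensorflow as tf

vocab_size, eps = 5, 1e-10
label = tf.constant([2, 0])                 # reference labels, [batch]
pred_probs = tf.fill([2, vocab_size], 0.2)  # model probs, [batch, vocab]
weight = tf.constant([[1.0], [0.0]])        # per-example bias weight

# eps keeps every entry positive so the final log() stays finite.
label_probs = tf.one_hot(
    label, vocab_size, on_value=1 - eps, off_value=eps / (vocab_size - 1))
probs = (1.0 - weight) * pred_probs + weight * label_probs
log_probs = tf.math.log(probs)  # row 0 is pinned to label 2; row 1 unchanged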
Example #5
def _GetTaskIds(self, source_id):
  """Look up the correct task_id from the source_id tensor."""
  if self.params.file_pattern_task_ids:
    file_task_ids = tf.constant(
        self.params.file_pattern_task_ids, dtype=tf.int32)
    source_id = tf.gather(file_task_ids, source_id)
  src_task_id = source_id
  tgt_task_id = source_id
  if self.params.task_to_src_lang_map:
    src_lang_ids = tf.constant(
        self.params.task_to_src_lang_map, dtype=tf.int32)
    src_task_id = tf.gather(src_lang_ids, src_task_id)
  if self.params.task_to_tgt_lang_map:
    tgt_lang_ids = tf.constant(
        self.params.task_to_tgt_lang_map, dtype=tf.int32)
    tgt_task_id = tf.gather(tgt_lang_ids, tgt_task_id)
  return src_task_id, tgt_task_id
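
Each tf.gather here is just an integer lookup table applied elementwise. A
toy eager example with a hypothetical mapping:

import tensorflow as tf

# Hypothetical mapping: file pattern 0 -> task 3, 1 -> task 1, 2 -> task 2.
file_task_ids = tf.constant([3, 1, 2], dtype=tf.int32)
source_id = tf.constant([0, 0, 2, 1], dtype=tf.int32)  # one id per example
print(tf.gather(file_task_ids, source_id).numpy())     # [3 3 2 1]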
    def gather(self, values, max_value=None, axis=0, batch_major_state=True):
        """Returns 'values' gathered at the ids provided to the constructor.

        Args:
          values: Values to gather from.
          max_value: The maximum possible value in `values`, used to decide
            whether the gather can be done exactly with bfloat16 matmuls.
          axis: Axis to gather on. Defaults to 0 (rows).
          batch_major_state: Whether the values to gather from use batch major
            or not. Defaults to True. For the Transformer model,
            batch_major_state is set to False (time is the major dim).

        Returns:
          Gathered values.

        Raises:
          ValueError: If dtype is not supported.
          NotImplementedError: If axis is not 0 or 1.
        """
        # Carry out the gather via matmul if the values are floating point or can
        # be represented exactly in bfloat16 (TPU internal matmul dtype).
        dtype = values.dtype
        if dtype in (tf.bfloat16, tf.float32,
                     tf.bool) or _has_bfloat16_repr(max_value):
            return self._matmul_gather(values,
                                       axis=axis,
                                       batch_major_state=batch_major_state)
        elif dtype == tf.int32:
            # For int32s with a max_value that can't be represented exactly in
            # floating point, we decompose `values` into parts that can be represented
            # exactly, gather each part individually, and recombine to get the final
            # gathered values.
            max_value = max_value or _MAX_INT32
            if max_value <= _MAX_BFLOAT16_INT**2:
                # Break 'values' into two bfloat16-representable parts. The low part
                # values are in [-255, 255]. High part values are in [-256, 256].
                signs = tf.sign(values)
                abs_values = tf.abs(values)
                low_part = signs * tf.bitwise.bitwise_and(abs_values, 0xff)
                high_part = signs * tf.bitwise.right_shift(abs_values, 8)
                low_part_gathered = self._matmul_gather(
                    low_part, axis=axis, batch_major_state=batch_major_state)
                high_part_gathered = self._matmul_gather(
                    high_part, axis=axis, batch_major_state=batch_major_state)
                return tf.bitwise.left_shift(high_part_gathered,
                                             8) + low_part_gathered
            else:
                # For larger magnitude int32s, we could break them up into 3 or 4 byte-
                # sized matmuls, but regular-old tf.gather() is more efficient.
                return tf.gather(values, self._ids, axis=axis)
        else:
            raise ValueError("Unsupported dtype %s" % values.dtype)
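
The byte decomposition above is exact because both parts stay within the
integer range that bfloat16 represents exactly. A standalone round-trip
check of just the decomposition (without the matmul gather itself):

import tensorflow as tf

values = tf.constant([300, -4095, 17, -1], dtype=tf.int32)
signs = tf.sign(values)
abs_values = tf.abs(values)
low = signs * tf.bitwise.bitwise_and(abs_values, 0xff)  # in [-255, 255]
high = signs * tf.bitwise.right_shift(abs_values, 8)    # small magnitude
# Gathering `low` and `high` separately and recombining is lossless:
recombined = tf.bitwise.left_shift(high, 8) + low
print(tf.reduce_all(tf.equal(recombined, values)).numpy())  # True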
Example #7
        def ReOrderHyps(x_in):
            """Reorders x_in based on prev hyp ids."""
            if isinstance(x_in, tf.Tensor) and x_in.shape.ndims > 0:
                # For rank > 1 tensors we make use of an efficient matmul based gather
                # on tpu that takes in account the range of the values. For R1, we
                # rely on the tf.gather and xla to optimize it efficiently for R1
                # layout.
                if x_in.shape.ndims > 1:
                    if p.batch_major_state:
                        num_hyps = tf.shape(old_hyp_ids)[0]
                        x_out = beam_search_tpu_ops.fast_gather(
                            x_in,
                            old_hyp_ids,
                            num_hyps,
                            max_value=None,
                            batch_major_state=p.batch_major_state)
                    else:
                        # Use corrected indices only here for batch major compute as
                        # key/value caches are the states being affected.
                        correct_old_hyp_ids = (old_hyp_ids_in_cache_order
                                               if p.batch_major_compute else
                                               old_hyp_ids)

                        def _GatherStep(x_in, t):
                            """Gather for one time step.

                            Args:
                              x_in: A tensor of shape [T, B, ...]. We take the
                                slice at time t, gather old_hyp_ids from that
                                slice, and write the gathered slice back in
                                place to update the original x_in.
                              t: The current time step.

                            Returns:
                              The updated x_in and the incremented time step.
                            """
                            x = tf.gather(tf.gather(x_in, t),
                                          correct_old_hyp_ids)
                            return inplace_ops.alias_inplace_update(
                                x_in, t, x), t + 1

                        x_out, _ = tf.while_loop(
                            lambda _, t: t <= cur_step, _GatherStep,
                            (x_in, tf.zeros([], tf.int32)))
                else:
                    x_out = tf.gather(x_in, old_hyp_ids)
                x_out.set_shape(x_in.get_shape())
                return x_out
            else:
                return x_in
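
In the time-major branch, only slices 0..cur_step have been written, so the
while loop rewrites just those in place and leaves future steps untouched. A
toy eager sketch of the same effect, using tf.tensor_scatter_nd_update in
place of the alias_inplace_update op:

import tensorflow as tf

T, B, D = 5, 3, 2
x_in = tf.reshape(tf.range(T * B * D, dtype=tf.float32), [T, B, D])
old_hyp_ids = tf.constant([2, 0, 0])
cur_step = 1  # steps 0..cur_step are the only ones written so far

updates = tf.gather(x_in[:cur_step + 1], old_hyp_ids, axis=1)
indices = tf.reshape(tf.range(cur_step + 1), [-1, 1])
x_out = tf.tensor_scatter_nd_update(x_in, indices, updates)
# Slices 0..1 are permuted along the hyp axis; slices 2..4 are unchanged.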
Example #8
def CollectVarHistogram(vs_gs):
  """Adds histogram summaries for variables and gradients."""

  for name, (var, grad) in vs_gs.FlattenItems():
    name = py_utils.SanitizeScopeKey(name)
    with tf.device(var.device), tf.name_scope(name + '/summary'):
      if isinstance(grad, tf.IndexedSlices):
        var = tf.gather(var, grad.indices)
        grad = grad.values
      if var.dtype.is_complex:
        var = tf.abs(var)
        grad = tf.abs(grad)

      histogram('var_hist/' + name, var)
      histogram('grad_hist/' + name, grad)
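
Sparse gradients (e.g. from embedding lookups) arrive as tf.IndexedSlices,
so the snippet gathers only the touched rows of the variable to keep the two
histograms comparable. A small eager illustration with made-up values:

import tensorflow as tf

var = tf.Variable(tf.zeros([10, 4]))
grad = tf.IndexedSlices(values=tf.ones([2, 4]),
                        indices=tf.constant([3, 7]),
                        dense_shape=tf.constant([10, 4]))
touched_var_rows = tf.gather(var, grad.indices)  # rows 3 and 7, shape [2, 4]
grad_values = grad.values                        # matching rows, shape [2, 4]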
    def _matmul_gather(self, values, axis=0, batch_major_state=True):
        """Returns values gathered.

        Args:
          values: Values to gather from.
          axis: Axis to gather on. Defaults to 0 (rows).
          batch_major_state: Whether the values to gather from use batch major
            or not. Defaults to True. For the Transformer model,
            batch_major_state is set to False (time is the major dim).

        Returns:
          Gathered values.

        Raises:
          NotImplementedError: If axis is not 0 or 1.
        """

        dtype = values.dtype
        if dtype != tf.float32 and dtype != tf.bfloat16:
            values = tf.cast(values, tf.float32)

        if axis == 0:
            if values.shape.rank is not None and values.shape.rank > 2:
                if not batch_major_state:
                    values = tf.transpose(values, [1, 0, 2])
                results = tf.cast(
                    tf.gather(values, tf.cast(self._ids, tf.int32)), dtype)
                # pylint:disable=g-long-ternary
                return (tf.transpose(results, [1, 0, 2])
                        if not batch_major_state else results)
                # pylint:enable=g-long-ternary
            else:
                one_hot_ids = tf.one_hot(self._ids,
                                         self._ids_size,
                                         dtype=values.dtype)
                return tf.cast(tf.matmul(one_hot_ids, values), dtype)
        elif axis == 1:
            one_hot_ids = tf.one_hot(self._ids,
                                     self._ids_size,
                                     dtype=values.dtype,
                                     axis=0)
            return tf.cast(tf.matmul(values, one_hot_ids), dtype)
        else:
            raise NotImplementedError("Only row/col-wise gather implemented.")
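
The identity being exploited: a row gather equals a matmul with a one-hot
selection matrix, which maps well onto the TPU's matrix unit. A quick eager
check at toy sizes:

import tensorflow as tf

values = tf.random.normal([6, 3])  # 6 rows to select from
ids = tf.constant([4, 0, 4])       # rows to pick (repeats are fine)
one_hot_ids = tf.one_hot(ids, 6, dtype=values.dtype)  # [3, 6]
print(tf.reduce_all(tf.equal(
    tf.matmul(one_hot_ids, values),    # [3, 3], gather-by-matmul
    tf.gather(values, ids))).numpy())  # True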
Example #10
  def _Pack(self, batch):
    """Packs a given batch.

    Note that this may change the batch size.

    This function packs the input batch and adds .segment_ids and .segment_pos
    fields to its `src` and `tgt` fields.

    Args:
      batch: a `.NestedMap` of input tensors to be packed. It is modified in
        place.
    """
    src_actual_seq_len = tf.math.reduce_sum(
        tf.cast(batch.src.ids_indicator, tf.int32), axis=1)
    tgt_actual_seq_len = tf.math.reduce_sum(
        tf.cast(batch.tgt.ids_indicator, tf.int32), axis=1)
    summary_utils.histogram('source_seq_lengths', src_actual_seq_len)
    summary_utils.histogram('target_seq_lengths', tgt_actual_seq_len)

    if not self.params.packing_factor:
      # Supply segment_ids and segment_pos with no packing.
      batch.src.segment_ids = batch.src.ids_indicator
      batch.src.segment_pos = _GetSegmentPos(batch.src.ids_indicator)
      batch.tgt.segment_ids = batch.tgt.ids_indicator
      batch.tgt.segment_pos = _GetSegmentPos(batch.tgt.ids_indicator)
      return

    (src_segment_ids, src_segment_pos, src_indices_in_input, tgt_segment_ids,
     tgt_segment_pos, tgt_indices_in_input) = ops.pack_sequences(
         src_actual_seq_len, tgt_actual_seq_len, self._ScaledBatchSize(),
         self.params.source_max_length, self.params.target_max_length)

    uniq_src_indices_in_input = tf.unique(
        tf.reshape(src_indices_in_input, [-1])).y
    uniq_tgt_indices_in_input = tf.unique(
        tf.reshape(tgt_indices_in_input, [-1])).y
    summary_utils.histogram(
        'packed_source_seq_lengths',
        tf.gather(src_actual_seq_len, uniq_src_indices_in_input, axis=0))
    summary_utils.histogram(
        'packed_target_seq_lengths',
        tf.gather(tgt_actual_seq_len, uniq_tgt_indices_in_input, axis=0))

    # We defer adding .paddings and exclusively use its complement,
    # .ids_indicator, so that we can apply the packing with the padding value
    # set to 0 for all fields.
    def ApplyPackingToSource(x):
      if x.dtype == tf.string:
        return ops.apply_packing(x, '\t', src_segment_ids, src_indices_in_input)
      return ops.apply_packing(x, 0, src_segment_ids, src_indices_in_input)

    batch.src = batch.src.Transform(ApplyPackingToSource)
    batch.src.segment_ids = tf.cast(src_segment_ids, tf.float32)
    batch.src.segment_pos = src_segment_pos

    def ApplyPackingToTarget(x):
      if x.dtype == tf.string:
        return ops.apply_packing(x, '\t', tgt_segment_ids, tgt_indices_in_input)
      return ops.apply_packing(x, 0, tgt_segment_ids, tgt_indices_in_input)

    batch.tgt = batch.tgt.Transform(ApplyPackingToTarget)
    batch.tgt.segment_ids = tf.cast(tgt_segment_ids, tf.float32)
    batch.tgt.segment_pos = tgt_segment_pos
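
ops.pack_sequences is a Lingvo custom op, but the packed format it produces
is easy to show by hand. If it packed two source sequences of lengths 2 and
3 into one row of length 8, the segment fields would look like this
(hand-constructed illustration, not the op's actual output):

import tensorflow as tf

segment_ids = tf.constant([[1, 1, 2, 2, 2, 0, 0, 0]])  # owning sequence
segment_pos = tf.constant([[0, 1, 0, 1, 2, 0, 0, 0]])  # position within it
# Downstream attention masks derive from segment_ids: a position may only
# attend within its own segment, and 0 marks padding.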