Code example #1
    def _InputBatch(self):
        p = self.params

        @tf.function
        def ReadData():
            x, y = io_ops.restore_v2(p.ckpt, [p.data, p.label], [''] * 2,
                                     [p.data_dtype, p.label_dtype])
            # Always convert to float32.
            return tf.cast(x, tf.float32), tf.cast(y, tf.float32)

        # Loads data and labels into memory and keeps them around.
        data, label = ops.cached_call(f=ReadData.get_concrete_function(),
                                      T=[tf.float32, tf.float32])
        b, shape = self.InfeedBatchSize(), list(p.data_shape)
        data = tf.reshape(data, [-1] + shape)
        label = tf.reshape(label, [-1])
        label = py_utils.HasShape(label, [tf.shape(data)[0]])
        sample_ids = ops.random_permutation_sequence(
            num=p.num_samples,
            batch=b,
            repeat=p.repeat,
            seed=p.random_seed if p.random_seed else 0)
        n = tf.shape(sample_ids)[0]
        raw = py_utils.PadOrTrimTo(tf.gather(data, sample_ids), [b] + shape)
        ret = py_utils.NestedMap(
            raw=raw,
            data=self._Preprocess(raw),
            label=py_utils.PadOrTrimTo(tf.gather(label, sample_ids), [b]),
            weight=py_utils.PadOrTrimTo(tf.ones([n], dtype=tf.float32), [b]))
        if not py_utils.use_tpu():
            ret['sample_ids'] = sample_ids
        return ret
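
The example above leans on Lingvo-specific ops (ops.cached_call, ops.random_permutation_sequence, py_utils.PadOrTrimTo). A minimal sketch of the same shuffle-gather-pad pattern using only stock TensorFlow ops; the function and argument names below are illustrative assumptions, not Lingvo APIs:

import tensorflow as tf

def shuffled_batch(data, label, batch_size, seed=0):
    """Draws a shuffled batch and pads it up to a fixed batch size."""
    num_samples = tf.shape(data)[0]
    sample_ids = tf.random.shuffle(tf.range(num_samples), seed=seed)[:batch_size]
    raw = tf.gather(data, sample_ids)
    lbl = tf.gather(label, sample_ids)
    # Pad the leading dim to batch_size when fewer samples remain.
    pad = batch_size - tf.shape(raw)[0]
    raw = tf.pad(raw, [[0, pad]] + [[0, 0]] * (raw.shape.rank - 1))
    lbl = tf.pad(lbl, [[0, pad]])
    return raw, lbl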
Code example #2
            def ApplyBias():
                """Bias and update log_probs and consistent."""
                def TileForBeamAndFlatten(tensor):
                    tensor = tf.reshape(tensor, [1, -1])  # [1, src_batch]
                    # [num_hyps_per_beam, src_batch]
                    tensor = tf.tile(tensor, [num_hyps_per_beam, 1])
                    tgt_batch = tf.shape(step_ids)[0]  # num_hyps_per_beam * src_batch
                    return tf.reshape(tensor, [tgt_batch])

                # Consistent if step_ids == labels from the previous step.
                # TODO(navari): Consider updating consistent only if weights > 0. Then
                # re-evaluate the need for bias_only_if_consistent=True.
                # Note that prev_label is incorrect for step 0 but is overridden later.
                prev_label = TileForBeamAndFlatten(
                    tf.gather(labels, tf.maximum(time_step - 1, 0), axis=1))
                is_step0 = tf.equal(time_step, 0)
                local_consistence = tf.math.logical_or(
                    is_step0, tf.equal(prev_label, tf.squeeze(step_ids, 1)))
                consistent = tf.math.logical_and(states.consistent,
                                                 local_consistence)

                # Get the label and weight slices for the current time_step.
                label = TileForBeamAndFlatten(
                    tf.gather(labels, time_step, axis=1))
                weight = TileForBeamAndFlatten(
                    tf.gather(weights, time_step, axis=1))
                if p.bias_only_if_consistent:
                    weight = weight * tf.cast(consistent, p.dtype)

                # Convert dense label ids to smoothed one-hot label probs.
                vocab_size = tf.shape(bs_results.log_probs)[1]
                # Avoid 0 probs, which may cause issues with log.
                uncertainty = tf.constant(1e-10, p.dtype)
                label_probs = tf.one_hot(
                    label,
                    vocab_size,
                    on_value=1 - uncertainty,
                    off_value=uncertainty / tf.cast(vocab_size - 1, p.dtype),
                    dtype=p.dtype)  # [tgt_batch, vocab_size]
                pred_probs = tf.exp(bs_results.log_probs)

                # Interpolate predicted probs and label probs.
                weight = tf.expand_dims(weight, 1)
                probs = py_utils.with_dependencies([
                    py_utils.assert_less_equal(weight, 1.),
                    py_utils.assert_greater_equal(weight, 0.)
                ], (1.0 - weight) * pred_probs + weight * label_probs)
                return tf.math.log(probs), consistent
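
The core of ApplyBias is an interpolation between the model's distribution and a smoothed one-hot label distribution. A standalone sketch of just that interpolation; the function and argument names are illustrative:

import tensorflow as tf

def biased_log_probs(log_probs, label, weight, eps=1e-10):
    """Mixes model probs with a smoothed one-hot label distribution."""
    vocab_size = tf.shape(log_probs)[1]
    # eps keeps every prob strictly positive so the final log is finite.
    label_probs = tf.one_hot(
        label, vocab_size,
        on_value=1.0 - eps,
        off_value=eps / tf.cast(vocab_size - 1, tf.float32))
    pred_probs = tf.exp(log_probs)
    weight = tf.expand_dims(weight, 1)  # [batch, 1], expected in [0, 1]
    return tf.math.log((1.0 - weight) * pred_probs + weight * label_probs)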
Code example #3
def _GatherStep(x_in, t):
    """Gather for one time step.

    Args:
      x_in: in the shape of [T, B, ...]; we first take slice(t) from the
        tensor, then gather correct_old_hyp_ids from the slice and write
        the gathered slice back in place to update the original x_in.
      t: current time step

    Returns:
      Updated x_in and time step
    """
    x = tf.gather(tf.gather(x_in, t), correct_old_hyp_ids)
    return inplace_ops.alias_inplace_update(x_in, t, x), t + 1
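
inplace_ops.alias_inplace_update is a TF-internal op; a portable sketch of the same per-step reorder can use tf.tensor_scatter_nd_update instead. The names old_hyp_ids and cur_step are assumptions carried over from the surrounding beam-search loop:

import tensorflow as tf

def reorder_time_major(x_in, old_hyp_ids, cur_step):
    """Reorders x_in[t] along the batch dim for every t <= cur_step."""
    def gather_step(x, t):
        # Reorder the batch dim of the slice at time t, then write it back.
        slice_t = tf.gather(x[t], old_hyp_ids)
        x = tf.tensor_scatter_nd_update(x, tf.reshape(t, [1, 1]),
                                        slice_t[tf.newaxis])
        return x, t + 1

    x_out, _ = tf.while_loop(lambda _, t: t <= cur_step, gather_step,
                             (x_in, tf.zeros([], tf.int32)))
    return x_out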
Code example #4
 def ReOrderHyps(x_in):
     """Reorders x_in based on prev hyp ids."""
     if (isinstance(x_in, tf.Tensor) and x_in.shape.ndims
             and x_in.shape.ndims > 0):
         if x_in.shape.ndims > 2 and not p.batch_major_state:
             # Use the corrected indices only for batch-major compute, as the
             # key/value caches are the states being affected.
             correct_old_hyp_ids = (old_hyp_ids_in_cache_order
                                    if p.batch_major_compute else
                                    old_hyp_ids)
             x_out = tf.gather(x_in, correct_old_hyp_ids, axis=1)
         else:
             x_out = tf.gather(x_in, old_hyp_ids)
         x_out.set_shape(x_in.get_shape())
         return x_out
     else:
         return x_in
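
One non-obvious detail above is the set_shape call: after a gather with dynamically shaped indices the static batch dimension can become unknown in graph mode, and since reordering never changes the tensor's shape it is safe to restore it. A small illustration, with assumed shapes:

import tensorflow as tf

x_in = tf.zeros([10, 8, 4])  # [T, B, ...] time-major decoder state
old_hyp_ids = tf.random.uniform([8], maxval=8, dtype=tf.int32)
x_out = tf.gather(x_in, old_hyp_ids, axis=1)  # reorder the batch dim
x_out.set_shape(x_in.get_shape())  # restore the static shape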
Code example #5
 def _GetTaskIds(self, source_id):
     """Look up the correct task_id from the source_id tensor."""
     if self.params.file_pattern_task_ids:
         file_task_ids = tf.constant(self.params.file_pattern_task_ids,
                                     dtype=tf.int32)
         source_id = tf.gather(file_task_ids, source_id)
     src_task_id = source_id
     tgt_task_id = source_id
     if self.params.task_to_src_lang_map:
         src_lang_ids = tf.constant(self.params.task_to_src_lang_map,
                                    dtype=tf.int32)
         src_task_id = tf.gather(src_lang_ids, src_task_id)
     if self.params.task_to_tgt_lang_map:
         tgt_lang_ids = tf.constant(self.params.task_to_tgt_lang_map,
                                    dtype=tf.int32)
         tgt_task_id = tf.gather(tgt_lang_ids, tgt_task_id)
     return src_task_id, tgt_task_id
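
The pattern here is tf.gather with a small constant table acting as a vectorized dictionary lookup, chained twice (file pattern -> task id -> language id). A toy example with made-up tables:

import tensorflow as tf

file_task_ids = tf.constant([3, 7, 7], dtype=tf.int32)  # file pattern -> task
task_to_src_lang = tf.constant([0, 0, 0, 1, 0, 0, 0, 2], dtype=tf.int32)

source_id = tf.constant([0, 2, 1])  # which file pattern each example came from
task_id = tf.gather(file_task_ids, source_id)    # -> [3, 7, 7]
src_lang = tf.gather(task_to_src_lang, task_id)  # -> [1, 2, 2]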
Code example #6
    def gather(self, values, max_value=None, axis=0, batch_major_state=True):
        """Returns 'values' gathered at the ids provided to the constructor.

    Args:
      values: Values to gather from.
      max_value: The largest of values.
      axis: Axis to gather on. Defaults to 0 (rows).
      batch_major_state: Whether the values to gather from use batch major or
        not. Defaults to True. For Transformer model, batch_major_state is set
        to False (time is the major dim).

    Returns:
      Gathered values.

    Raises:
      Value error: If dtype is not supported.
      NotImplemented error: if axis is not 0 or 1.
    """
        # Carry out the gather via matmul if the values are floating point or can
        # be represented exactly in bfloat16 (TPU internal matmul dtype).
        dtype = values.dtype
        if dtype in (tf.bfloat16, tf.float32,
                     tf.bool) or _has_bfloat16_repr(max_value):
            return self._matmul_gather(values,
                                       axis=axis,
                                       batch_major_state=batch_major_state)
        elif dtype == tf.int32:
            # For int32s with a max_value that can't be represented exactly in
            # floating point, we decompose `values` into parts that can be represented
            # exactly, gather each part individually, and recombine to get the final
            # gathered values.
            max_value = max_value or _MAX_INT32
            if max_value <= _MAX_BFLOAT16_INT**2:
                # Break 'values' into two bfloat16-representable parts. The low part
                # values are in [-255, 255]. High part values are in [-256, 256].
                signs = tf.sign(values)
                abs_values = tf.abs(values)
                low_part = signs * tf.bitwise.bitwise_and(abs_values, 0xff)
                high_part = signs * tf.bitwise.right_shift(abs_values, 8)
                low_part_gathered = self._matmul_gather(
                    low_part, axis=axis, batch_major_state=batch_major_state)
                high_part_gathered = self._matmul_gather(
                    high_part, axis=axis, batch_major_state=batch_major_state)
                return tf.bitwise.left_shift(high_part_gathered,
                                             8) + low_part_gathered
            else:
                # For larger magnitude int32s, we could break them up into 3 or 4 byte-
                # sized matmuls, but regular-old tf.gather() is more efficient.
                return tf.gather(values, self._ids, axis=axis)
        else:
            raise ValueError("Unsupported dtype %s" % values.dtype)
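
bfloat16 carries an 8-bit significand, so integers of magnitude up to 256 survive the cast exactly; the decomposition above keeps both parts inside that range whenever max_value <= 256**2. A self-contained sketch of the split-gather-recombine trick, with plain tf.gather standing in for the matmul-based gather:

import tensorflow as tf

def split_gather_int32(values, ids):
    """Gathers int32 values via two bfloat16-exact parts (|v| <= 2**16)."""
    signs = tf.sign(values)
    abs_values = tf.abs(values)
    low = signs * tf.bitwise.bitwise_and(abs_values, 0xff)  # in [-255, 255]
    high = signs * tf.bitwise.right_shift(abs_values, 8)    # in [-256, 256]
    # Each part is exactly representable in bfloat16, so a precision-lossy
    # gather path is safe for it.
    low_g = tf.gather(low, ids)
    high_g = tf.gather(high, ids)
    return tf.bitwise.left_shift(high_g, 8) + low_g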
Code example #7
        def ReOrderHyps(x_in):
            """Reorders x_in based on prev hyp ids."""
            if isinstance(x_in, tf.Tensor) and x_in.shape.ndims > 0:
                # For rank > 1 tensors we make use of an efficient matmul based gather
                # on tpu that takes in account the range of the values. For R1, we
                # rely on the tf.gather and xla to optimize it efficiently for R1
                # layout.
                if x_in.shape.ndims > 1:
                    if p.batch_major_state:
                        num_hyps = tf.shape(old_hyp_ids)[0]
                        x_out = beam_search_tpu_ops.fast_gather(
                            x_in,
                            old_hyp_ids,
                            num_hyps,
                            max_value=None,
                            batch_major_state=p.batch_major_state)
                    else:
                        # Use the corrected indices only for batch-major
                        # compute, as the key/value caches are the states
                        # being affected.
                        correct_old_hyp_ids = (old_hyp_ids_in_cache_order
                                               if p.batch_major_compute else
                                               old_hyp_ids)

                        def _GatherStep(x_in, t):
                            """Gather for one time step.

              Args:
                x_in: in the shape of [T, B, ...] we first get slice(t) from the
                  tensors, then gather old_hyp_ids from the slice and write the
                  interpolated slice inplace to update the original x_in.
                t: current time step

              Returns:
                Updated x_in and time step
              """
                            x = tf.gather(tf.gather(x_in, t),
                                          correct_old_hyp_ids)
                            return inplace_ops.alias_inplace_update(
                                x_in, t, x), t + 1

                        x_out, _ = tf.while_loop(
                            lambda _, t: t <= cur_step, _GatherStep,
                            (x_in, tf.zeros([], tf.int32)))
                else:
                    x_out = tf.gather(x_in, old_hyp_ids)
                x_out.set_shape(x_in.get_shape())
                return x_out
            else:
                return x_in
Code example #8
def CollectVarHistogram(vs_gs):
  """Adds histogram summaries for variables and gradients."""

  for name, (var, grad) in vs_gs.FlattenItems():
    name = py_utils.SanitizeScopeKey(name)
    with tf.device(var.device), tf.name_scope(name + '/summary'):
      if isinstance(grad, tf.IndexedSlices):
        var = tf.gather(var, grad.indices)
        grad = grad.values
      if var.dtype.is_complex:
        var = tf.abs(var)
        grad = tf.abs(grad)

      histogram('var_hist/' + name, var)
      histogram('grad_hist/' + name, grad)
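
Gradients of embedding lookups arrive as tf.IndexedSlices: only the gathered rows carry values. Gathering the same rows from the variable keeps the var and grad histograms aligned on the rows that were actually touched. A minimal demonstration:

import tensorflow as tf

emb = tf.Variable(tf.random.normal([100, 8]))
with tf.GradientTape() as tape:
    loss = tf.reduce_sum(tf.gather(emb, [3, 7, 7]))
grad = tape.gradient(loss, emb)          # a tf.IndexedSlices
var_rows = tf.gather(emb, grad.indices)  # rows aligned with grad.values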
Code example #9
    def _matmul_gather(self, values, axis=0, batch_major_state=True):
        """Returns values gathered.

    Args:
      values: Values to gather from.
      axis: Axis to gather on. Defaults to 0 (rows).
      batch_major_state: Whether the values to gather from use batch major or
        not. Defaults to True. For Transformer model, batch_major_state is set
        to False (time is the major dim).

    Returns:
      Gathered values.

    Raises:
      NotImplemented error if axis is not 0 nor 1.
    """

        dtype = values.dtype
        if dtype != tf.float32 and dtype != tf.bfloat16:
            values = tf.cast(values, tf.float32)

        if axis == 0:
            if values.shape.rank is not None and values.shape.rank > 2:
                if not batch_major_state:
                    values = tf.transpose(values, [1, 0, 2])
                results = tf.cast(
                    tf.gather(values, tf.cast(self._ids, tf.int32)), dtype)
                # pylint:disable=g-long-ternary
                return (tf.transpose(results, [1, 0, 2])
                        if not batch_major_state else results)
                # pylint:enable=g-long-ternary
            else:
                one_hot_ids = tf.one_hot(self._ids,
                                         self._ids_size,
                                         dtype=values.dtype)
                return tf.cast(tf.matmul(one_hot_ids, values), dtype)
        elif axis == 1:
            one_hot_ids = tf.one_hot(self._ids,
                                     self._ids_size,
                                     dtype=values.dtype,
                                     axis=0)
            return tf.cast(tf.matmul(values, one_hot_ids), dtype)
        else:
            raise NotImplementedError("Only row/col-wise gather implemented.")
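
The axis-0 fallback relies on the identity that multiplying a one-hot matrix into values reproduces a row gather; on TPU the matmul form is often the faster path. A quick numerical check:

import tensorflow as tf

values = tf.constant([[1., 2.], [3., 4.], [5., 6.]])
ids = tf.constant([2, 0])
one_hot_ids = tf.one_hot(ids, tf.shape(values)[0], dtype=values.dtype)
via_matmul = tf.matmul(one_hot_ids, values)  # [[5., 6.], [1., 2.]]
via_gather = tf.gather(values, ids)          # identical result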
Code example #10
    def _Pack(self, batch):
        """Packs a given batch.

    Note that this may change the batch size.

    This function packs the input batch and adds .segment_ids and .segment_pos
    fields to its `src` and `tgt` fields.

    Args:
      batch: a `.NestedMap` of input tensors to be packed. It is modified in
        place.
    """
        src_actual_seq_len = tf.math.reduce_sum(
            tf.cast(batch.src.ids_indicator, tf.int32), axis=1)
        tgt_actual_seq_len = tf.math.reduce_sum(
            tf.cast(batch.tgt.ids_indicator, tf.int32), axis=1)
        summary_utils.histogram('source_seq_lengths', src_actual_seq_len)
        summary_utils.histogram('target_seq_lengths', tgt_actual_seq_len)

        if not self.params.packing_factor:
            # Supply segment_ids and segment_pos with no packing.
            batch.src.segment_ids = batch.src.ids_indicator
            batch.src.segment_pos = _GetSegmentPos(batch.src.ids_indicator)
            batch.tgt.segment_ids = batch.tgt.ids_indicator
            batch.tgt.segment_pos = _GetSegmentPos(batch.tgt.ids_indicator)
            return

        (src_segment_ids, src_segment_pos, src_indices_in_input,
         tgt_segment_ids, tgt_segment_pos,
         tgt_indices_in_input) = ops.pack_sequences(
             src_actual_seq_len, tgt_actual_seq_len, self._ScaledBatchSize(),
             self.params.source_max_length, self.params.target_max_length)

        uniq_src_indices_in_input = tf.unique(
            tf.reshape(src_indices_in_input, [-1])).y
        uniq_tgt_indices_in_input = tf.unique(
            tf.reshape(tgt_indices_in_input, [-1])).y
        summary_utils.histogram(
            'packed_source_seq_lengths',
            tf.gather(src_actual_seq_len, uniq_src_indices_in_input, axis=0))
        summary_utils.histogram(
            'packed_target_seq_lengths',
            tf.gather(tgt_actual_seq_len, uniq_tgt_indices_in_input, axis=0))

        # We defer adding .paddings and use its complement .ids_indicator
        # exclusively, so that we can apply the packing with padding set to 0
        # for all fields.
        def ApplyPackingToSource(x):
            if x.dtype == tf.string:
                return ops.apply_packing(x, '\t', src_segment_ids,
                                         src_indices_in_input)
            return ops.apply_packing(x, 0, src_segment_ids,
                                     src_indices_in_input)

        batch.src = batch.src.Transform(ApplyPackingToSource)
        batch.src.segment_ids = tf.cast(src_segment_ids, tf.float32)
        batch.src.segment_pos = src_segment_pos

        def ApplyPackingToTarget(x):
            if x.dtype == tf.string:
                return ops.apply_packing(x, '\t', tgt_segment_ids,
                                         tgt_indices_in_input)
            return ops.apply_packing(x, 0, tgt_segment_ids,
                                     tgt_indices_in_input)

        batch.tgt = batch.tgt.Transform(ApplyPackingToTarget)
        batch.tgt.segment_ids = tf.cast(tgt_segment_ids, tf.float32)
        batch.tgt.segment_pos = tgt_segment_pos
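
_GetSegmentPos is not shown above. For the unpacked branch, a plausible implementation (an assumption, not the actual Lingvo helper) numbers each valid position within its row and zeroes out padding:

import tensorflow as tf

def get_segment_pos(ids_indicator):
    """Per-row positions for valid tokens; 0 at padded positions."""
    ids_indicator = tf.cast(ids_indicator, tf.int32)
    pos = tf.math.cumsum(ids_indicator, axis=1, exclusive=True)
    return pos * ids_indicator  # zero out padded positions

mask = tf.constant([[1, 1, 1, 0], [1, 1, 0, 0]])
# get_segment_pos(mask) -> [[0, 1, 2, 0], [0, 1, 0, 0]]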