from typing import Dict, List, Mapping, Sequence

import jax
import tensorflow as tf

# AUTOTUNE is referenced by the final _pack_with_tf_ops example below.
AUTOTUNE = tf.data.experimental.AUTOTUNE


def _pack_with_custom_ops(
        dataset: tf.data.Dataset,
        feature_lengths: Mapping[str, int]) -> tf.data.Dataset:
    """Helper-function for packing a dataset which has already been batched.

  See trim_and_pack_dataset()

  Relies on custom ops that require a custom-compiled binary.
  Faster than _pack_with_tf_ops() and produces denser packing.

  Args:
    dataset: a dataset containing padded batches of examples.
    feature_lengths: mapping from feature key to packed length.

  Returns:
    a dataset.
  """
    # TODO(adarob): Move ops into this library and fix int64 issue.
    from tensor2tensor.data_generators.ops import pack_sequences_ops  # pylint: disable=g-import-not-at-top
    keys = list(feature_lengths)
    if len(keys) == 1:
        k1, = keys
        k2 = k1
    elif len(keys) == 2:
        k1, k2 = keys
    else:
        raise ValueError(f"Packing op requires 1 or 2 keys. Got {len(keys)}")

    def custom_pack_batch(x):
        """Map-function."""
        (k1_packed, k1_segment_ids, k1_positions, k2_packed, k2_segment_ids,
         k2_positions) = (
             pack_sequences_ops.pack_sequences2(
                 # cast to int64 for compatibility with custom ops
                 tf.cast(x[k1], tf.int64),
                 tf.cast(x[k2], tf.int64),
                 feature_lengths[k1],
                 feature_lengths[k2]))
        packed = {
            k1: k1_packed,
            k1 + "_segment_ids": k1_segment_ids,
            k1 + "_positions": k1_positions,
        }
        if len(keys) == 2:
            packed.update({
                k2: k2_packed,
                k2 + "_segment_ids": k2_segment_ids,
                k2 + "_positions": k2_positions,
            })

        # cast back to int32
        for k, v in packed.items():
            packed[k] = tf.cast(v, tf.int32)

        return packed

    dataset = dataset.map(custom_pack_batch,
                          num_parallel_calls=tf.data.experimental.AUTOTUNE)
    dataset = dataset.unbatch()
    return dataset
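# Usage sketch (illustrative only, not from the original source): the input
# dataset must already be padded-batched before calling, e.g.
#   ds = ds.padded_batch(128, padded_shapes={k: [-1] for k in feature_lengths})
#   ds = _pack_with_custom_ops(ds, {"inputs": 512, "targets": 128})
# and the pack_sequences_ops binary from tensor2tensor must be importable.
# Each output example then carries "<key>", "<key>_segment_ids" and
# "<key>_positions" as rank-1 int32 tensors of length feature_lengths[key].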
Example #2
def _repeat_batch(batch_sizes: Sequence[int],
                  ds: tf.data.Dataset,
                  repeat: int = 1) -> tf.data.Dataset:
    """Tiles the inner most batch dimension."""
    if repeat <= 1:
        return ds
    if batch_sizes[-1] % repeat != 0:
        raise ValueError(
            f'The last element of `batch_sizes` ({batch_sizes}) must '
            f'be divisible by `repeat` ({repeat}).')
    # Perform regular batching with reduced number of elements.
    for i, batch_size in enumerate(reversed(batch_sizes)):
        ds = ds.batch(batch_size // repeat if i == 0 else batch_size,
                      drop_remainder=True)
    # Repeat batch.
    fn = lambda x: tf.repeat(x, repeats=repeat, axis=len(batch_sizes) - 1)

    def repeat_inner_batch(example):
        return jax.tree_map(fn, example)

    ds = ds.map(repeat_inner_batch, num_parallel_calls=tf.data.AUTOTUNE)
    # Unbatch.
    for _ in batch_sizes:
        ds = ds.unbatch()
    return ds
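# Minimal usage sketch (illustrative, not from the original source): with
# batch_sizes=[4] and repeat=2 each element of the flat stream is emitted
# twice in a row, i.e. 0, 0, 1, 1, 2, 2, ...
_repeat_demo = _repeat_batch([4], tf.data.Dataset.range(8), repeat=2)
# list(_repeat_demo.as_numpy_iterator()) == [0, 0, 1, 1, 2, 2, 3, 3]


def _strip_packed_feature_key(key: str) -> str:
    # Referenced by _pack_with_tf_ops below but not included in this excerpt;
    # a minimal reconstruction of the assumed behavior: drop the "_positions"
    # / "_segment_ids" suffix to recover the base feature key.
    strip = lambda k, suffix: k[:-len(suffix)] if k.endswith(suffix) else k
    return strip(strip(key, "_positions"), "_segment_ids")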
def _pack_with_tf_ops(dataset: tf.data.Dataset,
                      feature_lengths: Mapping[str, int]) -> tf.data.Dataset:
    """Helper-function for packing a dataset which has already been batched.

  See trim_and_pack_dataset()

  Uses tf.while_loop. Slow.

  Args:
    dataset: a dataset containing padded batches of examples.
    feature_lengths: mapping from feature key to packed length.

  Returns:
    a dataset.
  """
    empty_example = {}
    for k in feature_lengths:
        for suff in ("", "_positions"):
            empty_example[k + suff] = tf.zeros([0], dtype=tf.int32)
            empty_example[k + suff].set_shape([None])
    keys_etc = empty_example.keys()

    def _write_packed_example(partial, outputs):
        """Flushes `partial` into `outputs`, padded to the packed lengths."""
        new_partial = empty_example.copy()
        new_outputs = {}
        for k in keys_etc:
            new_outputs[k] = outputs[k].write(
                outputs[k].size(),
                tf.pad(partial[k], [[
                    0, feature_lengths[_strip_packed_feature_key(k)] -
                    tf.size(partial[k])
                ]]))
        return new_partial, new_outputs

    def pack_batch(x: Mapping[str, tf.Tensor]) -> Mapping[str, tf.Tensor]:
        """Internal function to map over.

    Consumes a batch of input examples and produces a variable number of output
    examples.

    Args:
      x: a padded batch of examples (a mapping from feature key to Tensor).

    Returns:
      a mapping from feature key to packed Tensors, including the derived
      "_segment_ids" and "_positions" features.
    """
        keys = list(feature_lengths)
        partial = empty_example.copy()
        first_key, *_ = keys
        dynamic_batch_size = tf.shape(x[first_key])[0]
        outputs = {}
        for k in keys:
            outputs[k] = tf.TensorArray(tf.int32,
                                        size=0,
                                        dynamic_size=True,
                                        element_shape=[feature_lengths[k]])
            outputs[k + "_positions"] = tf.TensorArray(
                tf.int32,
                size=0,
                dynamic_size=True,
                element_shape=[feature_lengths[k]])

        for i in tf.range(0, dynamic_batch_size):
            tf.autograph.experimental.set_loop_options(shape_invariants=[
                (partial, {k: tf.TensorShape([None]) for k in keys_etc}),
                (outputs, {k: tf.TensorShape(None) for k in keys_etc}),
            ])

            can_append = True
            one_example = {}
            for k in keys:
                val = tf.cast(x[k][i], tf.int32)
                # Strip the trailing padding (zeros) to recover the true length.
                val = val[:tf.reduce_sum(
                    tf.cast(tf.not_equal(val, 0), tf.int32))]
                one_example[k] = val
            for k in keys:
                can_append = tf.logical_and(
                    can_append,
                    tf.less_equal(
                        tf.size(partial[k]) + tf.size(one_example[k]),
                        feature_lengths[k]))

            # Flush the current pack when this example no longer fits.
            if not can_append:
                partial, outputs = _write_packed_example(partial, outputs)

            new_partial = {}
            for k in keys:
                new_seq = one_example[k][:feature_lengths[k]]
                new_seq_len = tf.size(new_seq)
                new_partial[k] = tf.concat([partial[k], new_seq], 0)
                new_partial[k + "_positions"] = tf.concat([
                    partial[k + "_positions"],
                    tf.range(new_seq_len, dtype=tf.int32)
                ], 0)
            partial = new_partial

        partial, outputs = _write_packed_example(partial, outputs)
        packed = {k: outputs[k].stack() for k in keys_etc}
        # Derive 1-based segment ids: each position == 0 marks the start of a
        # new packed segment, while padding (zeros) keeps segment id 0.
        for k in keys:
            packed[k + "_segment_ids"] = (tf.cumsum(
                tf.cast(tf.equal(packed[k + "_positions"], 0), tf.int32),
                axis=1) * tf.cast(tf.not_equal(packed[k], 0), tf.int32))
        return packed

    dataset = dataset.map(pack_batch,
                          num_parallel_calls=tf.data.experimental.AUTOTUNE)
    return dataset.unbatch()
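# Usage sketch (illustrative values, not from the original source): three short
# padded examples fit into a single packed row of lengths 8 / 4.
_packed_demo = _pack_with_tf_ops(
    tf.data.Dataset.from_tensor_slices({
        "inputs": [[3, 1, 4, 0], [5, 9, 0, 0], [2, 6, 5, 0]],
        "targets": [[7, 0], [8, 2], [9, 0]],
    }).batch(3),
    feature_lengths={"inputs": 8, "targets": 4})
# Iterating yields one packed example with
#   inputs             == [3, 1, 4, 5, 9, 2, 6, 5]
#   inputs_segment_ids == [1, 1, 1, 2, 2, 3, 3, 3]
#   inputs_positions   == [0, 1, 2, 0, 1, 0, 1, 2]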
Example #4
def _pack_with_tf_ops(dataset: tf.data.Dataset, keys: List[str],
                      key2length: Dict[str, int]) -> tf.data.Dataset:
    """Helper-function for packing a dataset which has already been batched.

  Helper for pack_dataset()  Uses tf.while_loop.

  Args:
    dataset: a dataset containing padded batches of examples.
    keys: a list of feature keys to pack
    key2length: a dict mapping each feature key to its packed length

  Returns:
    a dataset.
  """
    empty_example = {}
    for k in keys:
        empty_example[k] = tf.zeros([0], dtype=tf.int32)
        empty_example[k + '_position'] = tf.zeros([0], dtype=tf.int32)
    keys_etc = empty_example.keys()

    def write_packed_example(partial, outputs):
        """Flushes `partial` into `outputs`, padded to key2length."""
        new_partial = empty_example.copy()
        new_outputs = {}
        for k in keys_etc:
            new_outputs[k] = outputs[k].write(
                outputs[k].size(),
                tf.pad(partial[k], [[0, key2length[k] - tf.size(partial[k])]]))
        return new_partial, new_outputs

    def map_fn(x):
        """Internal function to flat_map over.

    Consumes a batch of input examples and produces a variable number of output
    examples.

    Args:
      x: a padded batch of examples (a mapping from feature key to Tensor).

    Returns:
      a mapping from feature key to packed Tensors, including the derived
      '_segmentation' and '_position' features.
    """
        partial = empty_example.copy()
        i = tf.zeros([], dtype=tf.int32)
        dynamic_batch_size = tf.shape(x[keys[0]])[0]
        outputs = {}
        for k in keys:
            outputs[k] = tf.TensorArray(tf.int32,
                                        size=0,
                                        dynamic_size=True,
                                        element_shape=[key2length[k]])
            outputs[k + '_position'] = tf.TensorArray(
                tf.int32,
                size=0,
                dynamic_size=True,
                element_shape=[key2length[k]])

        def body_fn(i, partial, outputs):
            """Body function for while_loop.

      Args:
        i: integer scalar
        partial: dictionary of Tensor (partially-constructed example)
        outputs: dictionary of TensorArray

      Returns:
        A triple containing the new values of the inputs.
      """
            can_append = True
            one_example = {}
            for k in keys:
                val = tf.cast(x[k][i], tf.int32)
                # Strip the trailing padding (zeros) to recover the true length.
                val = val[:tf.reduce_sum(
                    tf.cast(tf.not_equal(val, 0), tf.int32))]
                one_example[k] = val
            for k in keys:
                can_append = tf.logical_and(
                    can_append,
                    tf.less_equal(
                        tf.size(partial[k]) + tf.size(one_example[k]),
                        key2length[k]))

            def false_fn():
                # The example no longer fits: flush the partially-built pack.
                return write_packed_example(partial, outputs)

            def true_fn():
                # The example still fits: keep the current partial pack.
                return partial, outputs

            partial, outputs = tf.cond(can_append, true_fn, false_fn)
            new_partial = {}
            for k in keys:
                new_seq = one_example[k][:key2length[k]]
                new_seq_len = tf.size(new_seq)
                new_partial[k] = tf.concat([partial[k], new_seq], 0)
                new_partial[k + '_position'] = tf.concat(
                    [partial[k + '_position'],
                     tf.range(new_seq_len)], 0)
            partial = new_partial
            return i + 1, partial, outputs

        # For loop over all examples in the batch.
        i, partial, outputs = tf.while_loop(
            cond=lambda *_: True,
            body=body_fn,
            loop_vars=(i, partial, outputs),
            shape_invariants=(
                tf.TensorShape([]),
                {k: tf.TensorShape([None])
                 for k in keys_etc},
                {k: tf.TensorShape(None)
                 for k in keys_etc},
            ),
            maximum_iterations=dynamic_batch_size)
        _, outputs = write_packed_example(partial, outputs)
        packed = {k: outputs[k].stack() for k in keys_etc}
        # Derive 1-based segment ids: each position == 0 marks the start of a
        # new packed segment, while padding (zeros) keeps segment id 0.
        for k in keys:
            packed[k + '_segmentation'] = (tf.cumsum(
                tf.cast(tf.equal(packed[k + '_position'], 0), tf.int32),
                axis=1) * tf.cast(tf.not_equal(packed[k], 0), tf.int32))
        return packed

    dataset = dataset.map(map_fn, num_parallel_calls=AUTOTUNE)
    return dataset.unbatch()
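# Usage sketch (illustrative values, not from the original source). The caller
# (pack_dataset() per the docstring above) is expected to pre-populate
# key2length with the auxiliary '_position' entries before calling, since
# write_packed_example also pads those features to key2length[k + '_position'].
_keys = ['inputs', 'targets']
_key2length = {'inputs': 8, 'targets': 4}
for _k in _keys:
    _key2length[_k + '_position'] = _key2length[_k]
_packed_ds = _pack_with_tf_ops(
    tf.data.Dataset.from_tensor_slices({
        'inputs': [[3, 1, 4, 0], [5, 9, 0, 0]],
        'targets': [[7, 0], [8, 2]],
    }).batch(2),
    keys=_keys,
    key2length=_key2length)
# Each element also carries 'inputs_segmentation' / 'inputs_position'
# (and likewise for 'targets') alongside the packed base features.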