# Module-level imports assumed by the helpers below (not shown in the
# original excerpt).
from typing import Dict, List, Mapping, Sequence

import jax
import tensorflow as tf

# Alias used by the while_loop packing variant below.
AUTOTUNE = tf.data.experimental.AUTOTUNE


def _pack_with_custom_ops(
    dataset: tf.data.Dataset,
    feature_lengths: Mapping[str, int]) -> tf.data.Dataset:
  """Helper-function for packing a dataset which has already been batched.

  See trim_and_pack_dataset()

  Relies on custom ops which require a custom compiled binary.
  Faster than _pack_with_tf_ops(), and denser packing.

  Args:
    dataset: a dataset containing padded batches of examples.
    feature_lengths: mapping from feature key to packed length.

  Returns:
    a dataset.
  """
  # TODO(adarob): Move ops into this library and fix int64 issue.
  from tensor2tensor.data_generators.ops import pack_sequences_ops  # pylint: disable=g-import-not-at-top
  keys = list(feature_lengths)
  if len(keys) == 1:
    k1, = keys
    k2 = k1
  elif len(keys) == 2:
    k1, k2 = keys
  else:
    raise ValueError(f"Packing op requires 1 or 2 keys. Got {len(keys)}")

  def custom_pack_batch(x):
    """Map-function."""
    (k1_packed, k1_segment_ids, k1_positions,
     k2_packed, k2_segment_ids, k2_positions) = (
         pack_sequences_ops.pack_sequences2(
             # cast to int64 for compatibility with custom ops
             tf.cast(x[k1], tf.int64),
             tf.cast(x[k2], tf.int64),
             feature_lengths[k1],
             feature_lengths[k2]))
    packed = {
        k1: k1_packed,
        k1 + "_segment_ids": k1_segment_ids,
        k1 + "_positions": k1_positions,
    }
    if len(keys) == 2:
      packed.update({
          k2: k2_packed,
          k2 + "_segment_ids": k2_segment_ids,
          k2 + "_positions": k2_positions,
      })

    # cast back to int32
    for k, v in packed.items():
      packed[k] = tf.cast(v, tf.int32)
    return packed

  dataset = dataset.map(
      custom_pack_batch, num_parallel_calls=tf.data.experimental.AUTOTUNE)
  dataset = dataset.unbatch()
  return dataset
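
# Hedged usage sketch (assumed data; not part of the original library). The
# custom op packs two features jointly, so paired inputs/targets land in the
# same packed row. Requires the compiled tensor2tensor `pack_sequences_ops`
# binary, so this demo only runs where that import succeeds.
def _demo_pack_with_custom_ops() -> None:
  ds = tf.data.Dataset.from_tensor_slices({
      "inputs": [[1, 2, 0], [3, 0, 0]],
      "targets": [[4, 5, 0], [6, 0, 0]],
  })
  ds = ds.batch(2)  # the packing helpers expect a batched dataset
  packed = _pack_with_custom_ops(ds, {"inputs": 4, "targets": 4})
  for ex in packed.as_numpy_iterator():
    print(ex)  # packed features plus *_segment_ids and *_positions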

def _repeat_batch(batch_sizes: Sequence[int],
                  ds: tf.data.Dataset,
                  repeat: int = 1) -> tf.data.Dataset:
  """Tiles the innermost batch dimension."""
  if repeat <= 1:
    return ds
  if batch_sizes[-1] % repeat != 0:
    raise ValueError(
        f'The last element of `batch_sizes` ({batch_sizes}) must '
        f'be divisible by `repeat` ({repeat}).')
  # Perform regular batching with reduced number of elements.
  for i, batch_size in enumerate(reversed(batch_sizes)):
    ds = ds.batch(batch_size // repeat if i == 0 else batch_size,
                  drop_remainder=True)
  # Repeat batch.
  fn = lambda x: tf.repeat(x, repeats=repeat, axis=len(batch_sizes) - 1)

  def repeat_inner_batch(example):
    return jax.tree_map(fn, example)

  ds = ds.map(repeat_inner_batch, num_parallel_calls=tf.data.AUTOTUNE)
  # Unbatch.
  for _ in batch_sizes:
    ds = ds.unbatch()
  return ds
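
# Hedged usage sketch (assumed data; not part of the original library). With
# batch_sizes=[2, 4] and repeat=2, the inner dimension is first batched to
# 4 // 2 = 2 elements, then each element is repeated twice along the inner
# axis, so every value appears twice in a row after unbatching.
def _demo_repeat_batch() -> None:
  ds = tf.data.Dataset.range(8)
  ds = _repeat_batch(batch_sizes=[2, 4], ds=ds, repeat=2)
  print(list(ds.take(8).as_numpy_iterator()))  # [0, 0, 1, 1, 2, 2, 3, 3]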

# Helper assumed from the surrounding library (not shown in the original
# excerpt): strips the packing suffixes from a feature key so the packed
# length can be looked up, e.g. "inputs_positions" -> "inputs".
def _strip_packed_feature_key(key: str) -> str:
  strip_suffix = lambda k, s: k[:-len(s)] if k.endswith(s) else k
  return strip_suffix(strip_suffix(key, "_positions"), "_segment_ids")


def _pack_with_tf_ops(dataset: tf.data.Dataset,
                      feature_lengths: Mapping[str, int]) -> tf.data.Dataset:
  """Helper-function for packing a dataset which has already been batched.

  See trim_and_pack_dataset()

  Uses tf.while_loop. Slow.

  Args:
    dataset: a dataset containing padded batches of examples.
    feature_lengths: mapping from feature key to packed length.

  Returns:
    a dataset.
  """
  empty_example = {}
  for k in feature_lengths:
    for suff in ("", "_positions"):
      empty_example[k + suff] = tf.zeros([0], dtype=tf.int32)
      empty_example[k + suff].set_shape([None])
  keys_etc = empty_example.keys()

  def _write_packed_example(partial, outputs):
    new_partial = empty_example.copy()
    new_outputs = {}
    for k in keys_etc:
      new_outputs[k] = outputs[k].write(
          outputs[k].size(),
          tf.pad(
              partial[k],
              [[0, feature_lengths[_strip_packed_feature_key(k)] -
                tf.size(partial[k])]]))
    return new_partial, new_outputs

  def pack_batch(x: Mapping[str, tf.Tensor]) -> Mapping[str, tf.Tensor]:
    """Internal function to map over.

    Consumes a batch of input examples and produces a variable number of
    output examples.

    Args:
      x: a batch of examples.

    Returns:
      a batch of packed examples.
    """
    keys = list(feature_lengths)
    partial = empty_example.copy()

    first_key, *_ = keys
    dynamic_batch_size = tf.shape(x[first_key])[0]
    outputs = {}
    for k in keys:
      outputs[k] = tf.TensorArray(
          tf.int32, size=0, dynamic_size=True,
          element_shape=[feature_lengths[k]])
      outputs[k + "_positions"] = tf.TensorArray(
          tf.int32, size=0, dynamic_size=True,
          element_shape=[feature_lengths[k]])

    for i in tf.range(0, dynamic_batch_size):
      tf.autograph.experimental.set_loop_options(shape_invariants=[
          (partial, {k: tf.TensorShape([None]) for k in keys_etc}),
          (outputs, {k: tf.TensorShape(None) for k in keys_etc})
      ])

      can_append = True
      one_example = {}
      for k in keys:
        val = tf.cast(x[k][i], tf.int32)
        # Trim padding: keep only the non-zero entries.
        val = val[:tf.reduce_sum(tf.cast(tf.not_equal(val, 0), tf.int32))]
        one_example[k] = val
      for k in keys:
        can_append = tf.logical_and(
            can_append,
            tf.less_equal(
                tf.size(partial[k]) + tf.size(one_example[k]),
                feature_lengths[k]))

      if not can_append:
        partial, outputs = _write_packed_example(partial, outputs)

      new_partial = {}
      for k in keys:
        new_seq = one_example[k][:feature_lengths[k]]
        new_seq_len = tf.size(new_seq)
        new_partial[k] = tf.concat([partial[k], new_seq], 0)
        new_partial[k + "_positions"] = tf.concat(
            [partial[k + "_positions"],
             tf.range(new_seq_len, dtype=tf.int32)], 0)
      partial = new_partial

    partial, outputs = _write_packed_example(partial, outputs)
    packed = {k: outputs[k].stack() for k in keys_etc}
    for k in keys:
      packed[k + "_segment_ids"] = (
          tf.cumsum(
              tf.cast(tf.equal(packed[k + "_positions"], 0), tf.int32),
              axis=1) *
          tf.cast(tf.not_equal(packed[k], 0), tf.int32))
    return packed

  dataset = dataset.map(
      pack_batch, num_parallel_calls=tf.data.experimental.AUTOTUNE)
  return dataset.unbatch()
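
# Hedged usage sketch (assumed data; not part of the original library). Both
# padded examples fit within the packed length of 6, so they are packed into
# a single row, with *_positions restarting at 0 for each example and
# *_segment_ids numbering the examples from 1 (0 marks padding).
def _demo_pack_with_tf_ops() -> None:
  ds = tf.data.Dataset.from_tensor_slices(
      {"inputs": [[1, 2, 0, 0], [3, 4, 5, 0]]})
  ds = ds.batch(2)
  packed = _pack_with_tf_ops(ds, {"inputs": 6})
  # Expected single packed example:
  #   inputs:             [1, 2, 3, 4, 5, 0]
  #   inputs_positions:   [0, 1, 0, 1, 2, 0]
  #   inputs_segment_ids: [1, 1, 2, 2, 2, 0]
  for ex in packed.as_numpy_iterator():
    print(ex)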

def _pack_with_tf_ops(dataset: tf.data.Dataset, keys: List[str],
                      key2length: Dict[str, int]) -> tf.data.Dataset:
  """Helper-function for packing a dataset which has already been batched.

  Helper for pack_dataset()

  Uses tf.while_loop.

  Args:
    dataset: a dataset containing padded batches of examples.
    keys: a list of feature keys to pack.
    key2length: a dict mapping feature key to packed length.

  Returns:
    a dataset.
  """
  empty_example = {}
  for k in keys:
    empty_example[k] = tf.zeros([0], dtype=tf.int32)
    empty_example[k + '_position'] = tf.zeros([0], dtype=tf.int32)
  keys_etc = empty_example.keys()

  def write_packed_example(partial, outputs):
    new_partial = empty_example.copy()
    new_outputs = {}
    for k in keys_etc:
      new_outputs[k] = outputs[k].write(
          outputs[k].size(),
          tf.pad(partial[k], [[0, key2length[k] - tf.size(partial[k])]]))
    return new_partial, new_outputs

  def map_fn(x):
    """Internal function to map over.

    Consumes a batch of input examples and produces a variable number of
    output examples.

    Args:
      x: a batch of examples.

    Returns:
      a batch of packed examples.
    """
    partial = empty_example.copy()
    i = tf.zeros([], dtype=tf.int32)
    dynamic_batch_size = tf.shape(x[keys[0]])[0]
    outputs = {}
    for k in keys:
      outputs[k] = tf.TensorArray(
          tf.int32, size=0, dynamic_size=True, element_shape=[key2length[k]])
      outputs[k + '_position'] = tf.TensorArray(
          tf.int32, size=0, dynamic_size=True, element_shape=[key2length[k]])

    def body_fn(i, partial, outputs):
      """Body function for while_loop.

      Args:
        i: integer scalar
        partial: dictionary of Tensor (partially-constructed example)
        outputs: dictionary of TensorArray

      Returns:
        A triple containing the new values of the inputs.
      """
      can_append = True
      one_example = {}
      for k in keys:
        val = tf.cast(x[k][i], tf.int32)
        # Trim padding: keep only the non-zero entries.
        val = val[:tf.reduce_sum(tf.cast(tf.not_equal(val, 0), tf.int32))]
        one_example[k] = val
      for k in keys:
        can_append = tf.logical_and(
            can_append,
            tf.less_equal(
                tf.size(partial[k]) + tf.size(one_example[k]), key2length[k]))

      def false_fn():
        # Flush the current partial example before appending.
        return write_packed_example(partial, outputs)

      def true_fn():
        return partial, outputs

      partial, outputs = tf.cond(can_append, true_fn, false_fn)
      new_partial = {}
      for k in keys:
        new_seq = one_example[k][:key2length[k]]
        new_seq_len = tf.size(new_seq)
        new_partial[k] = tf.concat([partial[k], new_seq], 0)
        new_partial[k + '_position'] = tf.concat(
            [partial[k + '_position'],
             tf.range(new_seq_len)], 0)
      partial = new_partial
      return i + 1, partial, outputs

    # For loop over all examples in the batch.
    i, partial, outputs = tf.while_loop(
        cond=lambda *_: True,
        body=body_fn,
        loop_vars=(i, partial, outputs),
        shape_invariants=(
            tf.TensorShape([]),
            {k: tf.TensorShape([None]) for k in keys_etc},
            {k: tf.TensorShape(None) for k in keys_etc},
        ),
        maximum_iterations=dynamic_batch_size)
    _, outputs = write_packed_example(partial, outputs)
    packed = {k: outputs[k].stack() for k in keys_etc}
    for k in keys:
      packed[k + '_segmentation'] = (
          tf.cumsum(
              tf.cast(tf.equal(packed[k + '_position'], 0), tf.int32),
              axis=1) *
          tf.cast(tf.not_equal(packed[k], 0), tf.int32))
    return packed

  dataset = dataset.map(map_fn, num_parallel_calls=AUTOTUNE)
  return dataset.unbatch()
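
# Hedged usage sketch (assumed data; not part of the original library). This
# earlier variant shares its name with the one above but takes explicit
# `keys` and `key2length` arguments and emits singular '_position' /
# '_segmentation' suffixes. Note that key2length must also cover the
# '_position' keys; the caller (pack_dataset) normally fills those in.
def _demo_pack_with_tf_ops_while_loop() -> None:
  ds = tf.data.Dataset.from_tensor_slices(
      {'inputs': [[7, 8, 0, 0], [9, 0, 0, 0]]})
  ds = ds.batch(2)
  packed = _pack_with_tf_ops(
      ds, keys=['inputs'], key2length={'inputs': 6, 'inputs_position': 6})
  # Expected single packed example:
  #   inputs:              [7, 8, 9, 0, 0, 0]
  #   inputs_position:     [0, 1, 0, 0, 0, 0]
  #   inputs_segmentation: [1, 1, 2, 0, 0, 0]
  for ex in packed.as_numpy_iterator():
    print(ex)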