def get_padded_one_shot_dataset(self, *, batch_shape, split, shard_id,
                                num_shards):
  """Returns a single-pass, unshuffled, sharded and padded dataset.

  No example is dropped: instead of discarding a ragged tail, the stream is
  padded by `deterministic_data.pad_dataset` over batch dims
  `(num_shards, *batch_shape)` before it is sharded (presumably so that every
  shard ends up with the same number of complete batches). Preprocessing runs
  with augmentation disabled and the data is never repeated or shuffled.

  Args:
    batch_shape: leading shape of the batches produced for this shard.
    split: which dataset split to load. Only 'train' and 'eval' are supported:
      they select `self.num_train` / `self.num_eval` as the cardinality handed
      to the padding step; any other value raises KeyError.
    shard_id: index of the shard to keep (e.g. process_index).
    num_shards: total number of shards (e.g. process_count).

  Returns:
    The batched and prefetched dataset for shard `shard_id`.
  """
  raw_examples = self._load_tfds(split=split, shuffle_seed=None)
  preprocess_fn = functools.partial(
      self._preprocess, split=split, augment=False)
  examples = raw_examples.map(
      preprocess_fn, num_parallel_calls=tf.data.AUTOTUNE)

  # Number of examples per supported split, needed by the padding step.
  examples_per_split = {'train': self.num_train, 'eval': self.num_eval}
  padded = deterministic_data.pad_dataset(
      examples,
      batch_dims=(num_shards,) + tuple(batch_shape),
      cardinality=examples_per_split[split])

  shard = padded.shard(num_shards=num_shards, index=shard_id)
  batched = batch_dataset(shard, batch_shape=batch_shape)
  return batched.prefetch(tf.data.AUTOTUNE)
def test_pad_dataset(self):
  """pad_dataset appends zero rows plus a False mask after the real examples."""
  batch_size = 20
  num_real = 12
  num_fill = batch_size - num_real

  source = tf.data.Dataset.from_tensor_slices({
      "x": tf.ones((num_real, 10)),
      "y": tf.ones(num_real),
  })
  padded = deterministic_data.pad_dataset(
      source,
      batch_dims=[batch_size],
      pad_up_to_batches=2,
      cardinality=num_real)

  # The first batch holds all real examples followed by zero padding.
  expected_first_batch = {
      "x": tf.concat(
          [tf.ones((num_real, 10)), tf.zeros((num_fill, 10))], axis=0),
      "y": tf.concat([tf.ones(num_real), tf.zeros(num_fill)], axis=0),
      "mask": tf.concat(
          [tf.ones(num_real, bool), tf.zeros(num_fill, bool)], axis=0),
  }
  first_batch = next(iter(padded.batch(batch_size)))
  self.assertAllClose(expected_first_batch, first_batch)
def test_pad_nested_dataset(self):
  """pad_dataset zero-pads every leaf of a nested structure and adds a mask."""
  num_real = 12
  batch_size = 20
  num_fill = batch_size - num_real

  def padded_leaf(*dims):
    # `num_real` rows of ones followed by `num_fill` rows of zeros.
    return tf.concat(
        [tf.ones((num_real, *dims)), tf.zeros((num_fill, *dims))], axis=0)

  nested_source = tf.data.Dataset.from_tensor_slices({
      "x": {"z": (tf.ones((num_real, 10)), tf.ones(num_real))},
      "y": tf.ones((num_real, 4)),
  })
  padded = deterministic_data.pad_dataset(
      nested_source,
      batch_dims=[batch_size],
      pad_up_to_batches=2,
      cardinality=num_real)

  expected_first_batch = {
      "x": {"z": (padded_leaf(10), padded_leaf())},
      "y": padded_leaf(4),
      "mask": tf.concat(
          [tf.ones(num_real, bool), tf.zeros(num_fill, bool)], axis=0),
  }
  first_batch = next(iter(padded.batch(batch_size)))
  self.assertAllClose(expected_first_batch, first_batch)
def compute_predictions_jax(
    model: PredictionModel, dataset: tf.data.Dataset, batch_size: int
) -> Iterator[Tuple[types.ModelPredictions, types.Features]]:
  """Yield the predictions of the given JAX model on the given dataset.

  Note that this also works in multi-host configurations. You have to make sure
  that this function gets called on all hosts. The results will be yielded only
  to the host with a `jax.process_index()` equal to 0.

  The batched dataset is padded (padding examples carry a False/0 mask) up to
  the smallest multiple of `jax.process_count()` batches, so every process
  feeds the same number of batches into the cross-device `all_gather`. String
  metadata is packed into fixed-length vectors of Unicode code points (at most
  1024) so it can be gathered, and unpacked again before being yielded.

  Args:
    model: A function that takes tensor-valued features and returns a vector of
      predictions.
    dataset: The dataset that the function consumes to produce the predictions.
      Its cardinality must be known.
    batch_size: The batch size that should be used.

  Yields:
    `(predictions, metadata)` pairs for every non-padding example, on process
    0 only.

  Raises:
    ValueError: if the cardinality of `dataset` is unknown or infinite.
  """
  max_string_length = 1024  # Fixed length (in code points) of packed strings.

  def _gather(inputs):
    return jax.lax.all_gather(inputs, "i")

  gather = jax.pmap(_gather, axis_name="i")

  def infer(features):
    probabilities = model(features)
    return_vals = (probabilities, features["metadata"], features["mask"])
    # pmap maps over a leading axis of size local_device_count.
    return_vals_reshaped = jax.tree_util.tree_map(
        lambda x: x.reshape((jax.local_device_count(), -1) + x.shape[1:]),
        return_vals)
    # Every device receives the same gathered result; keep the first copy.
    return jax.tree_util.tree_map(lambda x: x[0], gather(return_vals_reshaped))

  num_examples = int(dataset.cardinality())
  if num_examples < 0:
    raise ValueError(
        "The cardinality must be known when running JAX multi-host models.")
  total_batches = (num_examples + batch_size - 1) // batch_size
  num_processes = jax.process_count()
  # We want each shard (process) to get an equal number of batches, so round
  # up to the smallest multiple of the process count that is >= total_batches.
  total_batches_padded = (
      (total_batches + num_processes - 1) // num_processes) * num_processes
  logging.info("Total batches %d, rounded up to %d", total_batches,
               total_batches_padded)

  def pad_strings(array):
    """Packs a string tensor into a fixed-size code-point vector."""
    if array.dtype != tf.string:
      return array
    codepoints = tf.strings.unicode_decode(array, "UTF-8")
    # unicode_decode returns a RaggedTensor for non-scalar input; densify it
    # and drop the leading dimension.
    # NOTE(review): assumes the string tensor has shape [1]; only the first
    # row is kept otherwise — confirm against the metadata schema.
    if isinstance(codepoints, tf.RaggedTensor):
      codepoints = codepoints.to_tensor()[0]
    num_codepoints = tf.size(codepoints)
    with tf.control_dependencies([
        tf.compat.v1.assert_less_equal(num_codepoints, max_string_length)]):
      packed = tf.pad(codepoints, [[0, max_string_length - num_codepoints]])
    return {"__packed": tf.ensure_shape(packed, [max_string_length])}

  def unpad_strings(array):
    """Inverse of `pad_strings`; non-packed leaves become numpy arrays."""
    if not isinstance(array, dict):
      return np.asarray(array)
    with_trailing_zeros = bytes(tf.strings.unicode_encode(
        np.asarray(array["__packed"]).reshape(-1), "UTF-8").numpy())
    # NOTE: this also strips NUL characters that were genuinely at the end of
    # the original string.
    return with_trailing_zeros.rstrip(b"\x00")

  def pad_strings_in_metadata(features):
    """Only padding of the strings subject to a gather operation."""
    features["metadata"] = tf.nest.map_structure(pad_strings,
                                                 features["metadata"])
    return features

  def is_packed_string(x):
    return isinstance(x, dict) and "__packed" in x

  def flatten(array):
    # Merge the (device, per-device batch) axes back into a single batch axis.
    return array.reshape((-1,) + array.shape[2:])

  dataset = clu_dd.pad_dataset(
      dataset.map(pad_strings_in_metadata),
      batch_dims=[batch_size],
      pad_up_to_batches=total_batches_padded,
      cardinality=None,  # It will be inferred from the dataset.
  ).batch(batch_size)

  # The shard for the current process.
  dataset_shard = dataset.shard(num_processes, jax.process_index())
  logging.info("Batches per host: %d", int(dataset_shard.cardinality()))

  is_first_process = jax.process_index() == 0
  for features in dataset_shard.as_numpy_iterator():
    time_start = time.time()
    # There is a bug in XLA, the following fails for int8s.
    features["mask"] = features["mask"].astype(np.int32)
    outputs = jax.tree_util.tree_map(flatten, infer(features))
    # JAX dispatches asynchronously: copying the results to host memory blocks
    # until they are computed, so it must happen before the clock is read.
    # It also avoids one device->host transfer per `masks[i]` lookup below.
    predictions, metadatas, masks = jax.device_get(outputs)
    time_end = time.time()
    time_delta_per_example = (time_end - time_start) / predictions.shape[0]
    if not is_first_process:
      continue
    for i in range(predictions.shape[0]):
      if not masks[i]:
        continue  # Padding example.
      predictions_i = types.ModelPredictions(
          predictions=[predictions[i]], time_in_s=time_delta_per_example)
      metadata_i = _slice_dictionary(metadatas, i)
      metadata_i_unpadded = jax.tree_util.tree_map(
          unpad_strings, metadata_i, is_leaf=is_packed_string)
      yield predictions_i, metadata_i_unpadded